MasonCrinr committed on
Commit
8026e91
1 Parent(s): 1709f73

Upload 331 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +1 -0
  3. .gitignore +133 -0
  4. LICENSE +22 -0
  5. MANIFEST.in +2 -0
  6. README.md +284 -0
  7. apex/.gitignore +5 -0
  8. apex/.nojekyll +0 -0
  9. apex/LICENSE +11 -0
  10. apex/README.md +99 -0
  11. apex/apex.patch +42 -0
  12. apex/apex/RNN/README.md +1 -0
  13. apex/apex/RNN/RNNBackend.py +365 -0
  14. apex/apex/RNN/__init__.py +3 -0
  15. apex/apex/RNN/cells.py +84 -0
  16. apex/apex/RNN/models.py +54 -0
  17. apex/apex/__init__.py +13 -0
  18. apex/apex/amp/README.md +72 -0
  19. apex/apex/amp/__init__.py +5 -0
  20. apex/apex/amp/__version__.py +2 -0
  21. apex/apex/amp/_amp_state.py +70 -0
  22. apex/apex/amp/_initialize.py +268 -0
  23. apex/apex/amp/_process_optimizer.py +411 -0
  24. apex/apex/amp/amp.py +177 -0
  25. apex/apex/amp/compat.py +42 -0
  26. apex/apex/amp/frontend.py +399 -0
  27. apex/apex/amp/handle.py +280 -0
  28. apex/apex/amp/lists/__init__.py +0 -0
  29. apex/apex/amp/lists/functional_overrides.py +77 -0
  30. apex/apex/amp/lists/tensor_overrides.py +63 -0
  31. apex/apex/amp/lists/torch_overrides.py +103 -0
  32. apex/apex/amp/opt.py +103 -0
  33. apex/apex/amp/rnn_compat.py +53 -0
  34. apex/apex/amp/scaler.py +210 -0
  35. apex/apex/amp/utils.py +213 -0
  36. apex/apex/amp/wrap.py +276 -0
  37. apex/apex/fp16_utils/README.md +16 -0
  38. apex/apex/fp16_utils/__init__.py +16 -0
  39. apex/apex/fp16_utils/fp16_optimizer.py +643 -0
  40. apex/apex/fp16_utils/fp16util.py +187 -0
  41. apex/apex/fp16_utils/loss_scaler.py +186 -0
  42. apex/apex/multi_tensor_apply/__init__.py +4 -0
  43. apex/apex/multi_tensor_apply/multi_tensor_apply.py +30 -0
  44. apex/apex/normalization/__init__.py +1 -0
  45. apex/apex/normalization/fused_layer_norm.py +160 -0
  46. apex/apex/optimizers/__init__.py +2 -0
  47. apex/apex/optimizers/fp16_optimizer.py +274 -0
  48. apex/apex/optimizers/fused_adam.py +147 -0
  49. apex/apex/parallel/LARC.py +97 -0
  50. apex/apex/parallel/README.md +66 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tensorboardX/screenshots/image.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,133 @@
1
+ # Global
2
+ .DS_Store
3
+ .idea
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99
+ __pypackages__/
100
+
101
+ # Celery stuff
102
+ celerybeat-schedule
103
+ celerybeat.pid
104
+
105
+ # SageMath parsed files
106
+ *.sage.py
107
+
108
+ # Environments
109
+ .env
110
+ .venv
111
+ env/
112
+ venv/
113
+ ENV/
114
+ env.bak/
115
+ venv.bak/
116
+
117
+ # Spyder project settings
118
+ .spyderproject
119
+ .spyproject
120
+
121
+ # Rope project settings
122
+ .ropeproject
123
+
124
+ # mkdocs documentation
125
+ /site
126
+
127
+ # mypy
128
+ .mypy_cache/
129
+ .dmypy.json
130
+ dmypy.json
131
+
132
+ # Pyre type checker
133
+ .pyre/
LICENSE ADDED
@@ -0,0 +1,22 @@
1
+ Noncommercial Use License
2
+
3
+ Software Copyright (c) 2020 OpenAI
4
+
5
+ We don’t claim ownership of the content you create with Jukebox.
6
+ We only ask that you use Jukebox responsibly and clearly indicate your content was created using OpenAI’s Jukebox.
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
9
+ documentation files (the "Software"), to deal in the Software, including without limitation the rights to use, copy,
10
+ modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the
11
+ Software is furnished to do so, subject to the following conditions:
12
+
13
+ No portion of the Software, nor any content created with the Software, may be used for commercial purposes.
14
+
15
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16
+
17
+ The above copyright notice and this permission notice need not be included with content created by the Software.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
20
+ WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
21
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
22
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,2 @@
1
+ recursive-include jukebox *.py
2
+ recursive-include jukebox *.txt
README.md ADDED
@@ -0,0 +1,284 @@
1
+ **Status:** Archive (code is provided as-is, no updates expected)
2
+
3
+ # Jukebox
4
+ Code for "Jukebox: A Generative Model for Music"
5
+
6
+ [Paper](https://arxiv.org/abs/2005.00341)
7
+ [Blog](https://openai.com/blog/jukebox)
8
+ [Explorer](http://jukebox.openai.com/)
9
+ [Colab](https://colab.research.google.com/github/openai/jukebox/blob/master/jukebox/Interacting_with_Jukebox.ipynb)
10
+
11
+ # Install
12
+ Install the conda package manager from https://docs.conda.io/en/latest/miniconda.html
13
+
14
+ ```
15
+ # Required: Sampling
16
+ conda create --name jukebox python=3.7.5
17
+ conda activate jukebox
18
+ conda install mpi4py=3.0.3 # if this fails, try: pip install mpi4py==3.0.3
19
+ conda install pytorch=1.4 torchvision=0.5 cudatoolkit=10.0 -c pytorch
20
+ git clone https://github.com/openai/jukebox.git
21
+ cd jukebox
22
+ pip install -r requirements.txt
23
+ pip install -e .
24
+
25
+ # Required: Training
26
+ conda install av=7.0.01 -c conda-forge
27
+ pip install ./tensorboardX
28
+
29
+ # Optional: Apex for faster training with fused_adam
30
+ conda install pytorch=1.1 torchvision=0.3 cudatoolkit=10.0 -c pytorch
31
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
32
+ ```
33
+
34
+ # Sampling
35
+ ## Sampling from scratch
36
+ To sample normally, run one of the following commands. The model can be `5b`, `5b_lyrics`, or `1b_lyrics`.
37
+ ```
38
+ python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --sample_length_in_seconds=20 \
39
+ --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
40
+ ```
41
+ ```
42
+ python jukebox/sample.py --model=1b_lyrics --name=sample_1b --levels=3 --sample_length_in_seconds=20 \
43
+ --total_sample_length_in_seconds=180 --sr=44100 --n_samples=16 --hop_fraction=0.5,0.5,0.125
44
+ ```
45
+ The above generates the first `sample_length_in_seconds` seconds of audio from a song of total length `total_sample_length_in_seconds`.
46
+ To use multiple GPUs, launch the above scripts as `mpiexec -n {ngpus} python jukebox/sample.py ...` so they use `{ngpus}` GPUs in parallel.
47
+
48
+ The samples decoded from each level are stored in `{name}/level_{level}`.
49
+ You can also view the samples as an HTML page with the aligned lyrics under `{name}/level_{level}/index.html`.
50
+ Run `python -m http.server` and open the HTML page through the server to see the lyrics animate as the song plays.
51
+ A summary of all sampling data including zs, x, labels and sampling_kwargs is stored in `{name}/level_{level}/data.pth.tar`.
52
+
53
+ The hps are for a V100 GPU with 16 GB GPU memory. The `1b_lyrics`, `5b`, and `5b_lyrics` top-level priors take up
54
+ 3.8 GB, 10.3 GB, and 11.5 GB, respectively. The peak memory usage to store transformer key, value cache is about 400 MB
55
+ for `1b_lyrics` and 1 GB for `5b_lyrics` per sample. If you are having trouble with CUDA OOM issues, try `1b_lyrics` or
56
+ decrease `max_batch_size` in sample.py, and `--n_samples` in the script call.
57
+
58
+ On a V100, it takes about 3 hrs to fully sample 20 seconds of music. Since this is a long time, it is recommended to use `n_samples > 1` so you can generate as many samples as possible in parallel. The 1B lyrics and upsamplers can process 16 samples at a time, while 5B can fit only up to 3. Since the vast majority of time is spent on upsampling, we recommend using a multiple of 3 less than 16 like `--n_samples 15` for `5b_lyrics`. This will make the top-level generate samples in groups of three while upsampling is done in one pass.
59
+
60
+ To continue sampling from already generated codes for a longer duration, you can run
61
+ ```
62
+ python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --mode=continue \
63
+ --codes_file=sample_5b/level_0/data.pth.tar --sample_length_in_seconds=40 --total_sample_length_in_seconds=180 \
64
+ --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
65
+ ```
66
+ Here, we take the 20-second samples saved from the first sampling run at `sample_5b/level_0/data.pth.tar` and continue by adding 20 more seconds.
67
+
68
+ You could also continue directly from the level 2 saved outputs; just pass `--codes_file=sample_5b/level_2/data.pth.tar`.
69
+ Note that this will upsample the full 40-second song at the end.
70
+
71
+ If you stopped sampling at only the first level and want to upsample the saved codes, you can run
72
+ ```
73
+ python jukebox/sample.py --model=5b_lyrics --name=sample_5b --levels=3 --mode=upsample \
74
+ --codes_file=sample_5b/level_2/data.pth.tar --sample_length_in_seconds=20 --total_sample_length_in_seconds=180 \
75
+ --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
76
+ ```
77
+ Here, we take the 20-second samples saved from the first sampling run at `sample_5b/level_2/data.pth.tar` and upsample the lower two levels.
78
+
79
+ ## Prompt with your own music
80
+ If you want to prompt the model with your own creative pieces or any other music, first save them as wave files and run
81
+ ```
82
+ python jukebox/sample.py --model=5b_lyrics --name=sample_5b_prompted --levels=3 --mode=primed \
83
+ --audio_file=path/to/recording.wav,awesome-mix.wav,fav-song.wav,etc.wav --prompt_length_in_seconds=12 \
84
+ --sample_length_in_seconds=20 --total_sample_length_in_seconds=180 --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
85
+ ```
86
+ This will load the four files, tile them to fill up to `n_samples` batch size, and prime the model with the first `prompt_length_in_seconds` seconds.
87
+
88
+ # Training
89
+ ## VQVAE
90
+ To train a small vqvae, run
91
+ ```
92
+ mpiexec -n {ngpus} python jukebox/train.py --hps=small_vqvae --name=small_vqvae --sample_length=262144 --bs=4 \
93
+ --audio_files_dir={audio_files_dir} --labels=False --train --aug_shift --aug_blend
94
+ ```
95
+ Here, `{audio_files_dir}` is the directory in which you can put the audio files for your dataset, and `{ngpus}` is the number of GPUs you want to use for training.
96
+ The above trains a two-level VQ-VAE with `downs_t = (5,3)` and `strides_t = (2, 2)`, meaning we downsample the audio by `2**5 = 32` to get the first-level codes and by `2**8 = 256` to get the second-level codes.
97
+ Checkpoints are stored in the `logs` folder. You can monitor the training by running Tensorboard
98
+ ```
99
+ tensorboard --logdir logs
100
+ ```
101
+
102
+ ## Prior
103
+ ### Train prior or upsamplers
104
+ Once the VQ-VAE is trained, we can restore it from its saved checkpoint and train priors on the learnt codes.
105
+ To train the top-level prior, we can run
106
+
107
+ ```
108
+ mpiexec -n {ngpus} python jukebox/train.py --hps=small_vqvae,small_prior,all_fp16,cpu_ema --name=small_prior \
109
+ --sample_length=2097152 --bs=4 --audio_files_dir={audio_files_dir} --labels=False --train --test --aug_shift --aug_blend \
110
+ --restore_vqvae=logs/small_vqvae/checkpoint_latest.pth.tar --prior --levels=2 --level=1 --weight_decay=0.01 --save_iters=1000
111
+ ```
112
+
113
+ To train the upsampler, we can run
114
+ ```
115
+ mpiexec -n {ngpus} python jukebox/train.py --hps=small_vqvae,small_upsampler,all_fp16,cpu_ema --name=small_upsampler \
116
+ --sample_length=262144 --bs=4 --audio_files_dir={audio_files_dir} --labels=False --train --test --aug_shift --aug_blend \
117
+ --restore_vqvae=logs/small_vqvae/checkpoint_latest.pth.tar --prior --levels=2 --level=0 --weight_decay=0.01 --save_iters=1000
118
+ ```
119
+ We pass `sample_length = n_ctx * downsample_of_level` so that after downsampling the tokens match the n_ctx of the prior hps.
120
+ Here, `n_ctx = 8192` and `downsamples = (32, 256)`, giving `sample_lengths = (8192 * 32, 8192 * 256) = (262144, 2097152)` respectively for the bottom and top level.
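For concreteness, here is a small illustrative calculation (plain Python written for this note, not code from the repository) of how those sample lengths follow from the hop parameters used above:

```python
# Illustrative arithmetic only: sample_length = n_ctx * cumulative downsampling
# of the level being trained.
n_ctx = 8192
downs_t, strides_t = (5, 3), (2, 2)          # same hps as the small VQ-VAE above

cumulative, total = [], 1
for down, stride in zip(downs_t, strides_t):
    total *= stride ** down                  # per-level hop: 2**5 = 32, then 2**3 = 8
    cumulative.append(total)                 # cumulative downsampling: [32, 256]

sample_lengths = [n_ctx * d for d in cumulative]
print(sample_lengths)                        # [262144, 2097152]
```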
121
+
122
+ ### Learning rate annealing
123
+ To get the best sample quality, anneal the learning rate to 0 near the end of training. To do so, continue training from the latest
124
+ checkpoint and run with
125
+ ```
126
+ --restore_prior="path/to/checkpoint" --lr_use_linear_decay --lr_start_linear_decay={already_trained_steps} --lr_decay={decay_steps_as_needed}
127
+ ```
128
+
129
+ ### Reuse pre-trained VQ-VAE and train top-level prior on new dataset from scratch.
130
+ #### Train without labels
131
+ Our pre-trained VQ-VAE can produce compressed codes for a wide variety of genres of music, and the pre-trained upsamplers
132
+ can upsample them back to audio that sounds very similar to the original audio.
133
+ To re-use these for a new dataset of your choice, you can retrain just the top-level prior.
134
+
135
+ To train top-level on a new dataset, run
136
+ ```
137
+ mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior \
138
+ --sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
139
+ --labels=False --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
140
+ ```
141
+ Training the `small_prior` with a batch size of 2, 4, and 8 requires 6.7 GB, 9.3 GB, and 15.8 GB of GPU memory, respectively. A few days to a week of training typically yields reasonable samples when the dataset is homogeneous (e.g. all piano pieces, songs of the same style, etc).
142
+
143
+ Near the end of training, follow [these steps](#learning-rate-annealing) to anneal the learning rate to 0.
144
+
145
+ #### Sample from new model
146
+ You can then run sample.py with the top-level of our models replaced by your new model. To do so,
147
+ - Add an entry `my_model=("vqvae", "upsampler_level_0", "upsampler_level_1", "small_prior")` in `MODELS` in `make_models.py`.
148
+ - Update the `small_prior` dictionary in `hparams.py` to include `restore_prior='path/to/checkpoint'`. If you
149
+ changed any hps directly in the command line script (e.g. `heads`), make sure to update them in the dictionary too so
150
+ that `make_models` restores your checkpoint correctly.
151
+ - Run sample.py as outlined in the sampling section, but now with `--model=my_model`
152
+
153
+ For example, let's say we trained `small_vqvae`, `small_prior`, and `small_upsampler` under `/path/to/jukebox/logs`. In `make_models.py`, we are going to declare a tuple of the new models as `my_model`.
154
+ ```
155
+ MODELS = {
156
+ '5b': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b"),
157
+ '5b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b_lyrics"),
158
+ '1b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_1b_lyrics"),
159
+ 'my_model': ("my_small_vqvae", "my_small_upsampler", "my_small_prior"),
160
+ }
161
+ ```
162
+
163
+ Next, in `hparams.py`, we add them to the registry with the corresponding `restore_` paths and any other command-line options used during training. Another important note is that for top-level priors with lyric conditioning, we have to locate a self-attention layer that shows alignment between the lyric and music tokens. Look for layers where `prior.prior.transformer._attn_mods[layer].attn_func` is either 6 or 7. If your model is starting to sing along with the lyrics, it means some (layer, head) pair has learned the alignment. Congrats!
164
+ ```
165
+ my_small_vqvae = Hyperparams(
166
+ restore_vqvae='/path/to/jukebox/logs/small_vqvae/checkpoint_some_step.pth.tar',
167
+ )
168
+ my_small_vqvae.update(small_vqvae)
169
+ HPARAMS_REGISTRY["my_small_vqvae"] = my_small_vqvae
170
+
171
+ my_small_prior = Hyperparams(
172
+ restore_prior='/path/to/jukebox/logs/small_prior/checkpoint_latest.pth.tar',
173
+ level=1,
174
+ labels=False,
175
+ # TODO For the two lines below, if `--labels` was used and the model is
176
+ # trained with lyrics, find and enter the layer, head pair that has learned
177
+ # alignment.
178
+ alignment_layer=47,
179
+ alignment_head=0,
180
+ )
181
+ my_small_prior.update(small_prior)
182
+ HPARAMS_REGISTRY["my_small_prior"] = my_small_prior
183
+
184
+ my_small_upsampler = Hyperparams(
185
+ restore_prior='/path/to/jukebox/logs/small_upsampler/checkpoint_latest.pth.tar',
186
+ level=0,
187
+ labels=False,
188
+ )
189
+ my_small_upsampler.update(small_upsampler)
190
+ HPARAMS_REGISTRY["my_small_upsampler"] = my_small_upsampler
191
+ ```
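If it helps, below is a small, non-authoritative sketch of listing the candidate alignment layers described above. It relies only on the `prior.prior.transformer._attn_mods[layer].attn_func` attribute path quoted in this section; the helper name is made up for illustration.

```python
# Hypothetical helper (not part of this commit): list layers whose attn_func is
# 6 or 7, i.e. the lyric/music attention layers worth inspecting for alignment.
def candidate_alignment_layers(top_prior):
    layers = []
    for i, attn in enumerate(top_prior.prior.transformer._attn_mods):
        if getattr(attn, "attn_func", None) in (6, 7):
            layers.append(i)
    return layers
```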
192
+
193
+ #### Train with labels
194
+ To train with your own metadata for your audio files, implement `get_metadata` in `data/files_dataset.py` to return the
195
+ `artist`, `genre` and `lyrics` for a given audio file. For now, you can pass `''` for lyrics to not use any lyrics.
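A minimal sketch of what such a `get_metadata` could look like, assuming you keep your own metadata table keyed by filename; the signature and the `self.metadata` lookup are illustrative, not the repository's actual code:

```python
# Illustrative only: return (artist, genre, lyrics) for one audio file.
def get_metadata(self, filename, test):
    meta = self.metadata[filename]   # assumed user-provided lookup table
    artist = meta["artist"]
    genre = meta["genre"]
    lyrics = ""                      # pass '' to train without lyrics
    return artist, genre, lyrics
```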
196
+
197
+ For training with labels, we'll use `small_labelled_prior` in `hparams.py`, and we set `labels=True,labels_v3=True`.
198
+ We use two kinds of label information:
199
+ - Artist/Genre:
200
+ - For each file, we return an artist_id and a list of genre_ids. The reason we have a list and not a single genre_id
201
+ is that in v2, we split genres like `blues_rock` into a bag of words `[blues, rock]`, and we pass at most
202
+ `max_bow_genre_size` of those; in `v3` we consider it a single word and just set `max_bow_genre_size=1`.
203
+ - Update the `v3_artist_ids` and `v3_genre_ids` to use ids from your new dataset.
204
+ - In `small_labelled_prior`, set the hps `y_bins = (number_of_genres, number_of_artists)` and `max_bow_genre_size=1`.
205
+ - Timing:
206
+ - For each chunk of audio, we return the `total_length` of the song, the `offset` the current audio chunk starts at, and
207
+ the `sample_length` of the audio chunk. We have three timing embeddings: total_length, our current position, and our
208
+ current position as a fraction of the total length, and we divide the range of these values into `t_bins` discrete bins (see the sketch after this list).
209
+ - In `small_labelled_prior`, set the hps `min_duration` and `max_duration` to be the shortest/longest duration of audio
210
+ files you want for your dataset, and `t_bins` for how many bins you want to discretize timing information into. Note
211
+ `min_duration * sr` needs to be at least `sample_length` to have an audio chunk in it.
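As referenced above, here is a purely illustrative sketch of how the three timing signals could be discretized into `t_bins` buckets. The variable names mirror the hps, but the actual binning code lives in the repository's labeller, not here.

```python
# Illustrative only: discretize total length, chunk position, and position as a
# fraction of the song into t_bins buckets (lengths measured in samples).
def timing_bins(total_length, offset, max_duration, sr, t_bins):
    max_samples = max_duration * sr

    def to_bin(value, scale):
        return min(int(t_bins * value / scale), t_bins - 1)

    total_bin = to_bin(total_length, max_samples)    # how long the song is
    position_bin = to_bin(offset, max_samples)       # where this chunk starts
    fraction_bin = to_bin(offset, total_length)      # position as a fraction
    return total_bin, position_bin, fraction_bin
```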
212
+
213
+ After these modifications, to train a top-level with labels, run
214
+ ```
215
+ mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_labelled_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_prior_labels \
216
+ --sample_length=1048576 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
217
+ --labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
218
+ ```
219
+
220
+ For sampling, follow the same instructions as [above](#sample-from-new-model) but use `small_labelled_prior` instead of `small_prior`.
221
+
222
+ #### Train with lyrics
223
+ To train in addition with lyrics, update `get_metadata` in `data/files_dataset.py` to return `lyrics` too.
224
+ For training with lyrics, we'll use `small_single_enc_dec_prior` in `hparams.py`.
225
+ - Lyrics:
226
+ - For each file, we linearly align the lyric characters to the audio, find the position in the lyrics that corresponds to
227
+ the midpoint of our audio chunk, and pass a window of `n_tokens` lyric characters centred around that (see the sketch after this list).
228
+ - In `small_single_enc_dec_prior`, set the hps `use_tokens=True` and `n_tokens` to be the number of lyric characters
229
+ to use for an audio chunk. Set it according to the `sample_length` you're training on so that it's large enough that
230
+ the lyrics for an audio chunk are almost always found inside a window of that size.
231
+ - If you use a non-English vocabulary, update `text_processor.py` with your new vocab and set
232
+ `n_vocab = number of characters in vocabulary` accordingly in `small_single_enc_dec_prior`. In v2, we had `n_vocab=80`,
233
+ and in v3 we missed `+`, so `n_vocab=79` characters.
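As referenced in the lyrics bullet above, here is a minimal sketch of centring an `n_tokens` window on the chunk midpoint; it assumes character-level lyrics and is written for illustration, not copied from the repository's text processing code:

```python
# Illustrative only: linearly map the midpoint of the audio chunk to a lyric
# position and take n_tokens characters centred on it.
def lyric_window(lyrics, offset, sample_length, total_length, n_tokens):
    midpoint = offset + sample_length // 2
    center = int(len(lyrics) * midpoint / total_length)   # linear alignment
    start = max(0, min(center - n_tokens // 2, len(lyrics) - n_tokens))
    return lyrics[start:start + n_tokens]
```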
234
+
235
+ After these modifications, to train a top-level with labels and lyrics, run
236
+ ```
237
+ mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,small_single_enc_dec_prior,all_fp16,cpu_ema --name=pretrained_vqvae_small_single_enc_dec_prior_labels \
238
+ --sample_length=786432 --bs=4 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
239
+ --labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
240
+ ```
241
+ To simplify hps choices, here we used a `single_enc_dec` model like the `1b_lyrics` model that combines both encoder and
242
+ decoder of the transformer into a single model. We do so by merging the lyric vocab and vq-vae vocab into a single
243
+ larger vocab, and flattening the lyric tokens and the vq-vae codes into a single sequence of length `n_ctx + n_tokens`.
244
+ This uses `attn_order=12` which includes `prime_attention` layers with keys/values from lyrics and queries from audio.
245
+ If you instead want to use a model with the usual encoder-decoder style transformer, use `small_sep_enc_dec_prior`.
246
+
247
+ For sampling, follow the same instructions as [above](#sample-from-new-model) but use `small_single_enc_dec_prior` instead of
248
+ `small_prior`. To also get the alignment between lyrics and samples in the saved html, you'll need to set `alignment_layer`
249
+ and `alignment_head` in `small_single_enc_dec_prior`. To find which layer/head is best to use, run a forward pass on a training example,
250
+ save the attention weight tensors for all prime_attention layers, and pick the (layer, head) which has the best linear alignment
251
+ pattern between the lyrics keys and music queries.
252
+
253
+ ### Fine-tune pre-trained top-level prior to new style(s)
254
+ Previously, we showed how to train a small top-level prior from scratch. Assuming you have a GPU with at least 15 GB of memory and support for fp16, you could fine-tune from our pre-trained 1B top-level prior. Here are the steps:
255
+
256
+ - Support `--labels=True` by implementing `get_metadata` in `jukebox/data/files_dataset.py` for your dataset.
257
+ - Add new entries in `jukebox/data/ids`. We recommend replacing existing mappings (e.g. rename `"unknown"`, etc. with styles of your choice). This uses the pre-trained style vectors as initialization and could potentially save some compute.
258
+
259
+ After these modifications, run
260
+ ```
261
+ mpiexec -n {ngpus} python jukebox/train.py --hps=vqvae,prior_1b_lyrics,all_fp16,cpu_ema --name=finetuned \
262
+ --sample_length=1048576 --bs=1 --aug_shift --aug_blend --audio_files_dir={audio_files_dir} \
263
+ --labels=True --train --test --prior --levels=3 --level=2 --weight_decay=0.01 --save_iters=1000
264
+ ```
265
+ To get the best sample quality, it is recommended to anneal the learning rate near the end of training. Training the 5B top-level requires GPipe, which is not supported in this release.
266
+
267
+ # Citation
268
+
269
+ Please cite using the following bibtex entry:
270
+
271
+ ```
272
+ @article{dhariwal2020jukebox,
273
+ title={Jukebox: A Generative Model for Music},
274
+ author={Dhariwal, Prafulla and Jun, Heewoo and Payne, Christine and Kim, Jong Wook and Radford, Alec and Sutskever, Ilya},
275
+ journal={arXiv preprint arXiv:2005.00341},
276
+ year={2020}
277
+ }
278
+ ```
279
+
280
+ # License
281
+ [Noncommercial Use License](./LICENSE)
282
+
283
+ It covers both released code and weights.
284
+
apex/.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ apex.egg-info
2
+ dist
3
+ build
4
+ docs/build
5
+ *~
apex/.nojekyll ADDED
File without changes
apex/LICENSE ADDED
@@ -0,0 +1,11 @@
1
+ All rights reserved.
2
+
3
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4
+
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
apex/README.md ADDED
@@ -0,0 +1,99 @@
1
+ # Introduction
2
+
3
+ This repository holds NVIDIA-maintained utilities to streamline
4
+ mixed precision and distributed training in Pytorch.
5
+ Some of the code here will be included in upstream Pytorch eventually.
6
+ The intention of Apex is to make up-to-date utilities available to
7
+ users as quickly as possible.
8
+
9
+ ## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
10
+
11
+ # Contents
12
+
13
+ ## 1. Amp: Automatic Mixed Precision
14
+
15
+ `apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
16
+ Users can easily experiment with different pure and mixed precision training modes by supplying
17
+ different flags to `amp.initialize`.
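As a quick illustration (a minimal sketch of the documented `amp.initialize`/`amp.scale_loss` usage, not code shipped in this commit), the three changes typically look like this:

```python
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# 1. Let amp patch the model and optimizer (opt_level "O1" is a common choice).
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()

# 2-3. Scale the loss and backprop through the scaled loss.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```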
18
+
19
+ [Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
20
+ (The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
21
+
22
+ [API Documentation](https://nvidia.github.io/apex/amp.html)
23
+
24
+ [Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
25
+
26
+ [DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
27
+
28
+ [Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
29
+
30
+ ## 2. Distributed Training
31
+
32
+ `apex.parallel.DistributedDataParallel` is a module wrapper, similar to
33
+ `torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
34
+ optimized for NVIDIA's NCCL communication library.
35
+
36
+ [API Documentation](https://nvidia.github.io/apex/parallel.html)
37
+
38
+ [Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)
39
+
40
+ [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)
41
+
42
+ The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
43
+ shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
44
+
45
+ ### Synchronized Batch Normalization
46
+
47
+ `apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
48
+ support synchronized BN.
49
+ It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
50
+ Synchronous BN has been used in cases where only a small
51
+ local minibatch can fit on each GPU.
52
+ Allreduced stats increase the effective batch size for the BN layer to the
53
+ global batch size across all processes (which, technically, is the correct
54
+ formulation).
55
+ Synchronous BN has been observed to improve converged accuracy in some of our research models.
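A short, non-authoritative sketch of enabling it with `apex.parallel.convert_syncbn_model`; the toy model and the distributed setup around it are assumptions, not taken from this commit:

```python
import torch
import apex

# Assumes torch.distributed has already been initialized
# (e.g. the script was launched with torch.distributed.launch).
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8)).cuda()

# Swap every torch.nn.modules.batchnorm._BatchNorm for apex.parallel.SyncBatchNorm ...
model = apex.parallel.convert_syncbn_model(model)

# ... and wrap with the apex DDP module so BN stats are allreduced across processes.
model = apex.parallel.DistributedDataParallel(model)
```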
56
+
57
+ # Requirements
58
+
59
+ Python 3
60
+
61
+ CUDA 9 or newer
62
+
63
+ PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.
64
+
65
+ We recommend the latest stable release, obtainable from
66
+ [https://pytorch.org/](https://pytorch.org/). We also test against the latest master branch, obtainable from [https://github.com/pytorch/pytorch](https://github.com/pytorch/pytorch).
67
+
68
+ It's often convenient to use Apex in Docker containers. Compatible options include:
69
+ * [NVIDIA Pytorch containers from NGC](https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch), which come with Apex preinstalled. To use the latest Amp API, you may need to `pip uninstall apex` then reinstall Apex using the **Quick Start** commands below.
70
+ * [official Pytorch -devel Dockerfiles](https://hub.docker.com/r/pytorch/pytorch/tags), e.g. `docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7`, in which you can install Apex using the **Quick Start** commands.
71
+
72
+ See the [Docker example folder](https://github.com/NVIDIA/apex/tree/master/examples/docker) for details.
73
+
74
+ # Quick Start
75
+
76
+ ### Linux
77
+
78
+ For performance and full functionality, we recommend installing Apex with
79
+ CUDA and C++ extensions via
80
+ ```
81
+ $ git clone https://github.com/NVIDIA/apex
82
+ $ cd apex
83
+ $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
84
+ ```
85
+
86
+ Apex also supports a Python-only build (required with Pytorch 0.4) via
87
+ ```
88
+ $ pip install -v --no-cache-dir .
89
+ ```
90
+ A Python-only build omits:
91
+ - Fused kernels required to use `apex.optimizers.FusedAdam`.
92
+ - Fused kernels required to use `apex.normalization.FusedLayerNorm`.
93
+ - Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
94
+ - Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
95
+ `DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
96
+
97
+ ### Windows support
98
+ Windows support is experimental, and Linux is recommended. `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
99
+ on your system. `pip install -v --no-cache-dir .` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
apex/apex.patch ADDED
@@ -0,0 +1,42 @@
1
+ diff --git a/csrc/fused_adam_cuda_kernel.cu b/csrc/fused_adam_cuda_kernel.cu
2
+ index 34f7aa2..95581d1 100644
3
+ --- a/csrc/fused_adam_cuda_kernel.cu
4
+ +++ b/csrc/fused_adam_cuda_kernel.cu
5
+ @@ -19,8 +19,8 @@ typedef enum{
6
+
7
+ template <typename T, typename GRAD_T>
8
+ __global__ void adam_cuda_kernel(
9
+ - T* __restrict__ p,
10
+ - GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
11
+ + GRAD_T* __restrict__ p,
12
+ + T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
13
+ T* __restrict__ m,
14
+ T* __restrict__ v,
15
+ const GRAD_T * __restrict__ g,
16
+ @@ -50,7 +50,7 @@ __global__ void adam_cuda_kernel(
17
+ else // Mode 1
18
+ denom = sqrtf(v[j]) + eps;
19
+ float update = (m[j]/denom) + (decay*p[j]);
20
+ - p[j] = p[j] - (step_size*update);
21
+ + p[j] = (GRAD_T) (p[j] - (step_size*update));
22
+ if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
23
+ }
24
+ }
25
+ @@ -93,14 +93,14 @@ void fused_adam_cuda(
26
+
27
+ if (g.scalar_type() == at::ScalarType::Half) {
28
+ //all other values should be fp32 for half gradients
29
+ - AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
30
+ +// AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
31
+ //dispatch is done on the gradient type
32
+ using namespace at; // prevents "toString is undefined" errors
33
+ DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
34
+ using accscalar_t = at::acc_type<scalar_t_0, true>;
35
+ adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
36
+ - p.data<accscalar_t>(),
37
+ - p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
38
+ + p.data<scalar_t_0>(),
39
+ + NULL, //don't output p_copy for fp32, it's wasted write
40
+ m.data<accscalar_t>(),
41
+ v.data<accscalar_t>(),
42
+ g.data<scalar_t_0>(),
apex/apex/RNN/README.md ADDED
@@ -0,0 +1 @@
1
+ Under construction...
apex/apex/RNN/RNNBackend.py ADDED
@@ -0,0 +1,365 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+
5
+ import torch.nn.functional as F
6
+
7
+ import math
8
+
9
+
10
+ def is_iterable(maybe_iterable):
11
+ return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple)
12
+
13
+
14
+ def flatten_list(tens_list):
15
+ """
16
+ flatten_list
17
+ """
18
+ if not is_iterable(tens_list):
19
+ return tens_list
20
+
21
+ return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() )
22
+
23
+
24
+ #These modules always assumes batch_first
25
+ class bidirectionalRNN(nn.Module):
26
+ """
27
+ bidirectionalRNN
28
+ """
29
+ def __init__(self, inputRNN, num_layers=1, dropout = 0):
30
+ super(bidirectionalRNN, self).__init__()
31
+ self.dropout = dropout
32
+ self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout)
33
+ self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout)
34
+ self.rnns = nn.ModuleList([self.fwd, self.bckwrd])
35
+
36
+ #collect hidden option will return all hidden/cell states from entire RNN
37
+ def forward(self, input, collect_hidden=False):
38
+ """
39
+ forward()
40
+ """
41
+ seq_len = input.size(0)
42
+ bsz = input.size(1)
43
+
44
+ fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden))
45
+ bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden))
46
+
47
+ output = torch.cat( [fwd_out, bckwrd_out], -1 )
48
+ hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) )
49
+
50
+ return output, hiddens
51
+
52
+ def reset_parameters(self):
53
+ """
54
+ reset_parameters()
55
+ """
56
+ for rnn in self.rnns:
57
+ rnn.reset_parameters()
58
+
59
+ def init_hidden(self, bsz):
60
+ """
61
+ init_hidden()
62
+ """
63
+ for rnn in self.rnns:
64
+ rnn.init_hidden(bsz)
65
+
66
+ def detach_hidden(self):
67
+ """
68
+ detach_hidden()
69
+ """
70
+ for rnn in self.rnns:
71
+ rnn.detachHidden()
72
+
73
+ def reset_hidden(self, bsz):
74
+ """
75
+ reset_hidden()
76
+ """
77
+ for rnn in self.rnns:
78
+ rnn.reset_hidden(bsz)
79
+
80
+ def init_inference(self, bsz):
81
+ """
82
+ init_inference()
83
+ """
84
+ for rnn in self.rnns:
85
+ rnn.init_inference(bsz)
86
+
87
+
88
+ #assumes hidden_state[0] of inputRNN is output hidden state
89
+ #constructor either takes an RNNCell or list of RNN layers
90
+ class stackedRNN(nn.Module):
91
+ """
92
+ stackedRNN
93
+ """
94
+ def __init__(self, inputRNN, num_layers=1, dropout=0):
95
+ super(stackedRNN, self).__init__()
96
+
97
+ self.dropout = dropout
98
+
99
+ if isinstance(inputRNN, RNNCell):
100
+ self.rnns = [inputRNN]
101
+ for i in range(num_layers-1):
102
+ self.rnns.append(inputRNN.new_like(inputRNN.output_size))
103
+ elif isinstance(inputRNN, list):
104
+ assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
105
+ self.rnns=inputRNN
106
+ else:
107
+ raise RuntimeError()
108
+
109
+ self.nLayers = len(self.rnns)
110
+
111
+ self.rnns = nn.ModuleList(self.rnns)
112
+
113
+
114
+ '''
115
+ Returns output as hidden_state[0] Tensor([sequence steps][batch size][features])
116
+ If collect hidden will also return Tuple(
117
+ [n_hidden_states][sequence steps] Tensor([layer][batch size][features])
118
+ )
119
+ If not collect hidden will also return Tuple(
120
+ [n_hidden_states] Tensor([layer][batch size][features])
121
+ '''
122
+ def forward(self, input, collect_hidden=False, reverse=False):
123
+ """
124
+ forward()
125
+ """
126
+ seq_len = input.size(0)
127
+ bsz = input.size(1)
128
+ inp_iter = reversed(range(seq_len)) if reverse else range(seq_len)
129
+
130
+ hidden_states = [[] for i in range(self.nLayers)]
131
+ outputs = []
132
+
133
+ for seq in inp_iter:
134
+ for layer in range(self.nLayers):
135
+
136
+ if layer == 0:
137
+ prev_out = input[seq]
138
+
139
+ outs = self.rnns[layer](prev_out)
140
+
141
+ if collect_hidden:
142
+ hidden_states[layer].append(outs)
143
+ elif seq == seq_len-1:
144
+ hidden_states[layer].append(outs)
145
+
146
+ prev_out = outs[0]
147
+
148
+ outputs.append(prev_out)
149
+
150
+ if reverse:
151
+ outputs = list(reversed(outputs))
152
+ '''
153
+ At this point outputs is in format:
154
+ list( [seq_length] x Tensor([bsz][features]) )
155
+ need to convert it to:
156
+ list( Tensor([seq_length][bsz][features]) )
157
+ '''
158
+ output = flatten_list(outputs)
159
+
160
+ '''
161
+ hidden_states at this point is in format:
162
+ list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) )
163
+ need to convert it to:
164
+ For not collect hidden:
165
+ list( [hidden_states] x Tensor([layer][bsz][features]) )
166
+ For collect hidden:
167
+ list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
168
+ '''
169
+ if not collect_hidden:
170
+ seq_len = 1
171
+ n_hid = self.rnns[0].n_hidden_states
172
+ new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ]
173
+
174
+
175
+ for i in range(n_hid):
176
+ for j in range(seq_len):
177
+ for k in range(self.nLayers):
178
+ new_hidden[i][j][k] = hidden_states[k][j][i]
179
+
180
+ hidden_states = new_hidden
181
+ #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) )
182
+ #Reverse seq_length if reverse
183
+ if reverse:
184
+ hidden_states = list( list(reversed(list(entry))) for entry in hidden_states)
185
+
186
+ #flatten layer dimension into tensor
187
+ hiddens = list( list(
188
+ flatten_list(seq) for seq in hidden )
189
+ for hidden in hidden_states )
190
+
191
+ #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) )
192
+ #Remove seq_length dimension if not collect_hidden
193
+ if not collect_hidden:
194
+ hidden_states = list( entry[0] for entry in hidden_states)
195
+ return output, hidden_states
196
+
197
+ def reset_parameters(self):
198
+ """
199
+ reset_parameters()
200
+ """
201
+ for rnn in self.rnns:
202
+ rnn.reset_parameters()
203
+
204
+ def init_hidden(self, bsz):
205
+ """
206
+ init_hidden()
207
+ """
208
+ for rnn in self.rnns:
209
+ rnn.init_hidden(bsz)
210
+
211
+ def detach_hidden(self):
212
+ """
213
+ detach_hidden()
214
+ """
215
+ for rnn in self.rnns:
216
+ rnn.detach_hidden()
217
+
218
+ def reset_hidden(self, bsz):
219
+ """
220
+ reset_hidden()
221
+ """
222
+ for rnn in self.rnns:
223
+ rnn.reset_hidden(bsz)
224
+
225
+ def init_inference(self, bsz):
226
+ """
227
+ init_inference()
228
+ """
229
+ for rnn in self.rnns:
230
+ rnn.init_inference(bsz)
231
+
232
+ class RNNCell(nn.Module):
233
+ """
234
+ RNNCell
235
+ gate_multiplier is related to the architecture you're working with
236
+ For LSTM-like it will be 4 and GRU-like will be 3.
237
+ Always assumes input is NOT batch_first.
238
+ Output size that's not hidden size will use output projection
239
+ Hidden_states is number of hidden states that are needed for cell
240
+ if one will go directly to cell as tensor, if more will go as list
241
+ """
242
+ def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None):
243
+ super(RNNCell, self).__init__()
244
+
245
+ self.gate_multiplier = gate_multiplier
246
+ self.input_size = input_size
247
+ self.hidden_size = hidden_size
248
+ self.cell = cell
249
+ self.bias = bias
250
+ self.output_size = output_size
251
+ if output_size is None:
252
+ self.output_size = hidden_size
253
+
254
+ self.gate_size = gate_multiplier * self.hidden_size
255
+ self.n_hidden_states = n_hidden_states
256
+
257
+ self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size))
258
+ self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size))
259
+
260
+ #Check if there's recurrent projection
261
+ if(self.output_size != self.hidden_size):
262
+ self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))
263
+
264
+ self.b_ih = self.b_hh = None
265
+ if self.bias:
266
+ self.b_ih = nn.Parameter(torch.Tensor(self.gate_size))
267
+ self.b_hh = nn.Parameter(torch.Tensor(self.gate_size))
268
+
269
+ #hidden states for forward
270
+ self.hidden = [ None for states in range(self.n_hidden_states)]
271
+
272
+ self.reset_parameters()
273
+
274
+ def new_like(self, new_input_size=None):
275
+ """
276
+ new_like()
277
+ """
278
+ if new_input_size is None:
279
+ new_input_size = self.input_size
280
+
281
+ return type(self)(self.gate_multiplier,
282
+ new_input_size,
283
+ self.hidden_size,
284
+ self.cell,
285
+ self.n_hidden_states,
286
+ self.bias,
287
+ self.output_size)
288
+
289
+
290
+ #Use xavier where we can (weights), otherwise use uniform (bias)
291
+ def reset_parameters(self, gain=1):
292
+ """
293
+ reset_parameters()
294
+ """
295
+ stdev = 1.0 / math.sqrt(self.hidden_size)
296
+ for param in self.parameters():
297
+ param.data.uniform_(-stdev, stdev)
298
+ '''
299
+ Xavier reset:
300
+ def reset_parameters(self, gain=1):
301
+ stdv = 1.0 / math.sqrt(self.gate_size)
302
+
303
+ for param in self.parameters():
304
+ if (param.dim() > 1):
305
+ torch.nn.init.xavier_normal(param, gain)
306
+ else:
307
+ param.data.uniform_(-stdv, stdv)
308
+ '''
309
+ def init_hidden(self, bsz):
310
+ """
311
+ init_hidden()
312
+ """
313
+ for param in self.parameters():
314
+ if param is not None:
315
+ a_param = param
316
+ break
317
+
318
+ for i, _ in enumerate(self.hidden):
319
+ if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz):
320
+
321
+ if i==0:
322
+ hidden_size = self.output_size
323
+ else:
324
+ hidden_size = self.hidden_size
325
+
326
+ tens = a_param.data.new(bsz, hidden_size).zero_()
327
+ self.hidden[i] = Variable(tens, requires_grad=False)
328
+
329
+
330
+ def reset_hidden(self, bsz):
331
+ """
332
+ reset_hidden()
333
+ """
334
+ for i, _ in enumerate(self.hidden):
335
+ self.hidden[i] = None
336
+ self.init_hidden(bsz)
337
+
338
+ def detach_hidden(self):
339
+ """
340
+ detach_hidden()
341
+ """
342
+ for i, _ in enumerate(self.hidden):
343
+ if self.hidden[i] is None:
344
+ raise RuntimeError("Must initialize hidden state before you can detach it")
345
+ for i, _ in enumerate(self.hidden):
346
+ self.hidden[i] = self.hidden[i].detach()
347
+
348
+ def forward(self, input):
349
+ """
350
+ forward()
351
+ if not inited or bsz has changed this will create hidden states
352
+ """
353
+ self.init_hidden(input.size()[0])
354
+
355
+ hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
356
+ self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh)
357
+ if(self.n_hidden_states > 1):
358
+ self.hidden = list(self.hidden)
359
+ else:
360
+ self.hidden=[self.hidden]
361
+
362
+ if self.output_size != self.hidden_size:
363
+ self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
364
+
365
+ return tuple(self.hidden)
apex/apex/RNN/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .models import LSTM, GRU, ReLU, Tanh, mLSTM
2
+
3
+ __all__ = ['models']
apex/apex/RNN/cells.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .RNNBackend import RNNCell
6
+
7
+ from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend
8
+
9
+ import math
10
+
11
+
12
+ class mLSTMRNNCell(RNNCell):
13
+ """
14
+ mLSTMRNNCell
15
+ """
16
+
17
+ def __init__(self, input_size, hidden_size, bias = False, output_size = None):
18
+ gate_multiplier = 4
19
+ super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size)
20
+
21
+ self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size))
22
+ self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size))
23
+
24
+ self.reset_parameters()
25
+
26
+ def forward(self, input):
27
+ """
28
+ mLSTMRNNCell.forward()
29
+ """
30
+ #if not inited or bsz has changed this will create hidden states
31
+ self.init_hidden(input.size()[0])
32
+
33
+ hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden
34
+
35
+ self.hidden = list(
36
+ self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh,
37
+ b_ih=self.b_ih, b_hh=self.b_hh)
38
+ )
39
+
40
+ if self.output_size != self.hidden_size:
41
+ self.hidden[0] = F.linear(self.hidden[0], self.w_ho)
42
+ return tuple(self.hidden)
43
+
44
+
45
+ def new_like(self, new_input_size=None):
46
+ if new_input_size is None:
47
+ new_input_size = self.input_size
48
+
49
+ return type(self)(
50
+ new_input_size,
51
+ self.hidden_size,
52
+ self.bias,
53
+ self.output_size)
54
+
55
+ def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None):
56
+ """
57
+ mLSTMCell
58
+ """
59
+
60
+ if input.is_cuda:
61
+ igates = F.linear(input, w_ih)
62
+ m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
63
+ hgates = F.linear(m, w_hh)
64
+
65
+ state = fusedBackend.LSTMFused.apply
66
+ return state(igates, hgates, hidden[1], b_ih, b_hh)
67
+
68
+ hx, cx = hidden
69
+
70
+ m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh)
71
+ gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh)
72
+
73
+ ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
74
+
75
+ ingate = F.sigmoid(ingate)
76
+ forgetgate = F.sigmoid(forgetgate)
77
+ cellgate = F.tanh(cellgate)
78
+ outgate = F.sigmoid(outgate)
79
+
80
+ cy = (forgetgate * cx) + (ingate * cellgate)
81
+ hy = outgate * F.tanh(cy)
82
+
83
+ return hy, cy
84
+
apex/apex/RNN/models.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+
3
+ from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell
4
+
5
+ from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell
6
+ from .cells import mLSTMRNNCell, mLSTMCell
7
+
8
+ def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0):
9
+ """
10
+ :class:`toRNNBackend`
11
+ """
12
+
13
+ if bidirectional:
14
+ return bidirectionalRNN(inputRNN, num_layers, dropout = dropout)
15
+ else:
16
+ return stackedRNN(inputRNN, num_layers, dropout = dropout)
17
+
18
+
19
+ def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
20
+ """
21
+ :class:`LSTM`
22
+ """
23
+ inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size)
24
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
25
+
26
+ def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
27
+ """
28
+ :class:`GRU`
29
+ """
30
+ inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size)
31
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
32
+
33
+ def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
34
+ """
35
+ :class:`ReLU`
36
+ """
37
+ inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size)
38
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
39
+
40
+ def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
41
+ """
42
+ :class:`Tanh`
43
+ """
44
+ inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size)
45
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
46
+
47
+ def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None):
48
+ """
49
+ :class:`mLSTM`
50
+ """
51
+ inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size)
52
+ return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout)
53
+
54
+
apex/apex/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from . import parallel
2
+ from . import amp
3
+ from . import fp16_utils
4
+
5
+ # For optimizers and normalization there is no Python fallback.
6
+ # Absence of cuda backend is a hard error.
7
+ # I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
8
+ # to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext
9
+ # so they expect those backends to be available, but for some reason they actually aren't
10
+ # available (for example because they built improperly in a way that isn't revealed until
11
+ # load time) the error message is timely and visible.
12
+ from . import optimizers
13
+ from . import normalization
apex/apex/amp/README.md ADDED
@@ -0,0 +1,72 @@
1
+ # amp: Automatic Mixed Precision
2
+
3
+ ## Annotating User Functions
4
+
5
+ Nearly all PyTorch user code needs nothing more than the two steps
6
+ above to use amp. After all, custom layers are built out of simpler
7
+ PyTorch components, and amp already can see those.
8
+
9
+ However, any custom C++ or CUDA code is outside of amp's (default)
10
+ view of things. For example, suppose I implemented a new recurrent
11
+ cell called a "forgetful recurrent unit" that calls directly into a
12
+ CUDA backend:
13
+
14
+ ```python
15
+ from backend import FRUBackend
16
+
17
+ def fru(input, hidden, weight, bias):
18
+ # call to CUDA code
19
+ FRUBackend(input, hidden, weight, bias)
20
+ ```
21
+
22
+ In this case, it is possible to get a runtime type mismatch. For
23
+ example, you might have `input` in fp16, and `weight` in fp32, and amp
24
+ doesn't have the visibility to insert an appropriate cast.
25
+
26
+ amp exposes two ways to handle "invisible" backend code: function
27
+ annotations and explicit registration.
28
+
29
+ #### Function annotation
30
+
31
+ The first way to handle backend code is a set of function annotations:
32
+
33
+ - `@amp.half_function`
34
+ - `@amp.float_function`
35
+ - `@amp.promote_function`
36
+
37
+ These correspond to:
38
+
39
+ - Cast all arguments to fp16
40
+ - Cast all arguments to fp32
41
+ - If there are any type mismatches, cast everything to the widest type
42
+
43
+ In our example, we believe that the FRU unit is fp16-safe and will get
44
+ performance gains from casting its arguments to fp16, so we write:
45
+
46
+ ```python
47
+ @amp.half_function
48
+ def fru(input, hidden, weight, bias):
49
+ #...
50
+ ```
51
+
52
+ #### Explicit registration
53
+
54
+ The other way to handle backend code is with explicit function
55
+ registration:
56
+
57
+ - `amp.register_half_function(module, function_name)`
58
+ - `amp.register_float_function(module, function_name)`
59
+ - `amp.register_promote_function(module, function_name)`
60
+
61
+ When using this API, `module` is the containing class or module for
62
+ the function, and `function_name` is the _string_ name of the
63
+ function. Note that the function must be registered before the call to
64
+ `amp.initialize()`.
65
+
66
+ For our FRU unit, we can register the backend function directly:
67
+
68
+ ```python
69
+ import backend
70
+
71
+ amp.register_half_function(backend, 'FRUBackend')
72
+ ```
apex/apex/amp/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .amp import init, half_function, float_function, promote_function,\
2
+ register_half_function, register_float_function, register_promote_function
3
+ from .handle import scale_loss, disable_casts
4
+ from .frontend import initialize
5
+ from ._amp_state import master_params, _amp_state
apex/apex/amp/__version__.py ADDED
@@ -0,0 +1,2 @@
1
+ VERSION = (0, 1, 0)
2
+ __version__ = '.'.join(map(str, VERSION))
apex/apex/amp/_amp_state.py ADDED
@@ -0,0 +1,70 @@
1
+ # This is a "header object" that allows different amp modules to communicate.
2
+ # I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like.
3
+ # But apparently it's ok:
4
+ # http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm
5
+ import os
6
+ import torch
7
+
8
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
9
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
10
+
11
+ if TORCH_MAJOR == 0:
12
+ import collections.abc as container_abcs
13
+ else:
14
+ from torch._six import container_abcs
15
+
16
+
17
+ class AmpState(object):
18
+ def __init__(self):
19
+ self.hard_override=False
20
+ self.allow_incoming_model_not_fp32 = False
21
+ self.verbosity=1
22
+
23
+
24
+ # Attribute stash. Could also just stash things as global module attributes.
25
+ _amp_state = AmpState()
26
+
27
+
28
+ def warn_or_err(msg):
29
+ if _amp_state.hard_override:
30
+ print("Warning: " + msg)
31
+ else:
32
+ raise RuntimeError(msg)
33
+ # I'm not sure if allowing hard_override is a good idea.
34
+ # + " If you're sure you know what you're doing, supply " +
35
+ # "hard_override=True to amp.initialize.")
36
+
37
+
38
+ distributed = False
39
+ if 'WORLD_SIZE' in os.environ:
40
+ distributed = int(os.environ['WORLD_SIZE']) > 1
41
+
42
+
43
+ def maybe_print(msg, rank0=False):
44
+ if _amp_state.verbosity > 0:
45
+ if rank0:
46
+ if distributed:
47
+ if torch.distributed.get_rank() == 0:
48
+ print(msg)
49
+ else:
50
+ print(msg)
51
+ else:
52
+ print(msg)
53
+
54
+
55
+ # def iter_params(param_groups):
56
+ # for group in param_groups:
57
+ # for p in group['params']:
58
+ # yield p
59
+
60
+
61
+ def master_params(optimizer):
62
+ """
63
+ Generator that iterates over the params owned by ``optimizer``.
64
+
65
+ Args:
66
+ optimizer: An optimizer previously returned from ``amp.initialize``.
67
+ """
68
+ for group in optimizer.param_groups:
69
+ for p in group['params']:
70
+ yield p
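A typical use of `master_params` is gradient clipping on whatever params the optimizer actually steps (which may be FP32 master weights rather than the model's FP16 params). A minimal hedged sketch, assuming `loss` and `optimizer` already come from an `amp.initialize` setup:

```python
# Sketch only: clip gradients on the optimizer-owned params after unscaling.
# Assumes `loss` and `optimizer` exist and were set up via amp.initialize(...).
import torch
from apex import amp

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
optimizer.step()
```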
apex/apex/amp/_initialize.py ADDED
@@ -0,0 +1,268 @@
1
+ import torch
2
+ from torch._six import string_classes
3
+ import functools
4
+ import numpy as np
5
+ import warnings
6
+ from ._amp_state import _amp_state, warn_or_err, container_abcs
7
+ from .handle import disable_casts
8
+ from .scaler import LossScaler
9
+ from ._process_optimizer import _process_optimizer
10
+ from apex.fp16_utils import convert_network
11
+ from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
12
+ from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
13
+ from ..optimizers import FusedAdam
14
+ from ..parallel import DistributedDataParallel as apex_DDP
15
+ from ..parallel.LARC import LARC
16
+
17
+
18
+ def to_type(dtype, t):
19
+ if isinstance(t, torch.Tensor):
20
+ if not t.is_cuda:
21
+ # This should not be a hard error, since it may be legitimate.
22
+ warnings.warn("An input tensor was not cuda.")
23
+ # GANs require this.
24
+ # if t.requires_grad:
25
+ # warn_or_err("input data requires grad. Since input data is not a model parameter,\n"
26
+ # "its gradients will not be properly allreduced by DDP.")
27
+ if t.is_floating_point():
28
+ return t.to(dtype)
29
+ return t
30
+ else:
31
+ # Trust the user's custom batch type, that's all I can do here.
32
+ return t.to(dtype)
33
+
34
+
35
+ # Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py.
36
+ def applier(value, fn):
37
+ if isinstance(value, torch.Tensor):
38
+ return fn(value)
39
+ elif isinstance(value, string_classes):
40
+ return value
41
+ elif isinstance(value, np.ndarray):
42
+ return value
43
+ elif hasattr(value, "to"): # Allow handling of custom batch classes
44
+ return fn(value)
45
+ elif isinstance(value, container_abcs.Mapping):
46
+ return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
47
+ elif isinstance(value, container_abcs.Iterable):
48
+ return type(value)(applier(v, fn) for v in value)
49
+ else:
50
+ # Do I want this to fire off even if someone chooses to pass something ordinary like
51
+ # an int or float? May be more annoying than it's worth.
52
+ # print("Warning: unrecognized type in applier. If your input data is a custom class, "
53
+ # "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. "
54
+ # "Amp will check for your custom to() and invoke it to cast the batch's "
55
+ # "floating-point Tensors to the appropriate type. "
56
+ # "Also, if your data is a custom class, it is your responsibility to ensure that "
57
+ # "any Tensors you want to be cuda are already cuda."
58
+ return value
59
+
60
+
61
+ def check_models(models):
62
+ for model in models:
63
+ parallel_type = None
64
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
65
+ parallel_type = "torch.nn.parallel.DistributedDataParallel"
66
+ if isinstance(model, apex_DDP):
67
+ parallel_type = "apex.parallel.DistributedDataParallel"
68
+ if isinstance(model, torch.nn.parallel.DataParallel):
69
+ parallel_type = "torch.nn.parallel.DataParallel"
70
+ if parallel_type is not None:
71
+ raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
72
+ "Parallel wrappers should only be applied to the model(s) AFTER \n"
73
+ "the model(s) have been returned from amp.initialize.")
74
+
75
+
76
+ def check_params_fp32(models):
77
+ for model in models:
78
+ for name, param in model.named_parameters():
79
+ if param.is_floating_point():
80
+ if 'Half' in param.type():
81
+ warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
82
+ "When using amp.initialize, you do not need to call .half() on your model\n"
83
+ "before passing it, no matter what optimization level you choose.".format(
84
+ name, param.type()))
85
+ elif not param.is_cuda:
86
+ warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
87
+ "When using amp.initialize, you need to provide a model with parameters\n"
88
+ "located on a CUDA device before passing it no matter what optimization level\n"
89
+ "you chose. Use model.to('cuda') to use the default device.".format(
90
+ name, param.type()))
91
+
92
+ # Backward compatibility for PyTorch 0.4
93
+ if hasattr(model, 'named_buffers'):
94
+ buf_iter = model.named_buffers()
95
+ else:
96
+ buf_iter = model._buffers
97
+ for obj in buf_iter:
98
+ if type(obj)==tuple:
99
+ name, buf = obj
100
+ else:
101
+ name, buf = obj, buf_iter[obj]
102
+ if buf.is_floating_point():
103
+ if 'Half' in buf.type():
104
+ warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
105
+ "When using amp.initialize, you do not need to call .half() on your model\n"
106
+ "before passing it, no matter what optimization level you choose.".format(
107
+ name, buf.type()))
108
+ elif not buf.is_cuda:
109
+ warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
110
+ "When using amp.initialize, you need to provide a model with buffers\n"
111
+ "located on a CUDA device before passing it no matter what optimization level\n"
112
+ "you chose. Use model.to('cuda') to use the default device.".format(
113
+ name, buf.type()))
114
+
115
+
116
+ def check_optimizers(optimizers):
117
+ for optim in optimizers:
118
+ bad_optim_type = None
119
+ if isinstance(optim, FP16_Optimizer_general):
120
+ bad_optim_type = "apex.fp16_utils.FP16_Optimizer"
121
+ if isinstance(optim, FP16_Optimizer_for_fused):
122
+ bad_optim_type = "apex.optimizers.FP16_Optimizer"
123
+ if bad_optim_type is not None:
124
+ raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) +
125
+ "The optimizer(s) passed to amp.initialize() must be bare \n"
126
+ "instances of either ordinary Pytorch optimizers, or Apex fused \n"
127
+ "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
128
+ "soon). You should not manually wrap your optimizer in either \n"
129
+ "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
130
+ "amp.initialize will take care of that for you (if necessary) based \n"
131
+ "on the specified opt_level (and optional overridden properties).")
132
+
133
+
134
+ def wrap_fused_adam(optimizer, properties):
135
+ msg = 'Currently, the usage of FusedAdam is restricted to '\
136
+ 'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
137
+ 'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
138
+
139
+ assert properties.master_weights is True, msg
140
+ assert properties.cast_model_type is torch.float16, msg
141
+ assert (properties.keep_batchnorm_fp32 is False or
142
+ properties.keep_batchnorm_fp32 is None), msg
143
+
144
+ if properties.loss_scale == "dynamic":
145
+ return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
146
+ else:
147
+ return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
148
+
149
+
150
+ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
151
+ from apex.parallel import DistributedDataParallel as apex_DDP
152
+ from .amp import init as amp_init
153
+
154
+ optimizers_was_list = False
155
+ if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
156
+ optimizers = [optimizers]
157
+ elif optimizers is None:
158
+ optimizers = []
159
+ elif isinstance(optimizers, list):
160
+ optimizers_was_list = True
161
+ check_optimizers(optimizers)
162
+ else:
163
+ check_optimizers([optimizers])
164
+ raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")
165
+
166
+ if isinstance(models, torch.nn.Module):
167
+ models_was_list = False
168
+ models = [models]
169
+ elif isinstance(models, list):
170
+ models_was_list = True
171
+ else:
172
+ raise TypeError("models must be either a single model or a list of models.")
173
+
174
+ check_models(models)
175
+
176
+ if not _amp_state.allow_incoming_model_not_fp32:
177
+ check_params_fp32(models)
178
+
179
+
180
+ # In the future, when FP16_Optimizer can be deprecated and master weights can
181
+ # become an attribute, remember to stash master weights before casting the model.
182
+
183
+ if properties.cast_model_type:
184
+ if properties.keep_batchnorm_fp32:
185
+ for model in models:
186
+ convert_network(model, properties.cast_model_type)
187
+ else:
188
+ for model in models:
189
+ model.to(properties.cast_model_type)
190
+
191
+ input_caster = functools.partial(to_type, properties.cast_model_type)
192
+ if cast_model_outputs is not None:
193
+ output_caster = functools.partial(to_type, cast_model_outputs)
194
+ else:
195
+ output_caster = functools.partial(to_type, torch.float32)
196
+
197
+ for model in models:
198
+ # Patch the forward method to cast incoming data to the correct type, and
199
+ # outgoing data to float32, so "the user never needs to call .half()."
200
+ # I like writing things explicitly more than decorators.
201
+ def patch_forward(old_fwd):
202
+ def new_fwd(*args, **kwargs):
203
+ output = old_fwd(*applier(args, input_caster),
204
+ **applier(kwargs, input_caster))
205
+ return applier(output, output_caster)
206
+ return new_fwd
207
+
208
+ model.forward = patch_forward(model.forward)
209
+
210
+ # State dict trick to recast any preexisting per-param state tensors
211
+ for optimizer in optimizers:
212
+ optimizer.load_state_dict(optimizer.state_dict())
213
+ elif cast_model_outputs is not None:
214
+ output_caster = functools.partial(to_type, cast_model_outputs)
215
+
216
+ for model in models:
217
+ def patch_forward(old_fwd):
218
+ def new_fwd(*args, **kwargs):
219
+ output = old_fwd(*args, **kwargs)
220
+ return applier(output, output_caster)
221
+ return new_fwd
222
+
223
+ model.forward = patch_forward(model.forward)
224
+
225
+ for i, optimizer in enumerate(optimizers):
226
+ # Still need to special case this for the first pass
227
+ if isinstance(optimizer, FusedAdam):
228
+ optimizers[i] = wrap_fused_adam(optimizer, properties)
229
+ else:
230
+ optimizers[i] = _process_optimizer(optimizer, properties)
231
+
232
+ _amp_state.loss_scalers = []
233
+ for _ in range(num_losses):
234
+ _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
235
+ min_loss_scale=_amp_state.min_loss_scale,
236
+ max_loss_scale=_amp_state.max_loss_scale))
237
+
238
+ if properties.patch_torch_functions:
239
+ # handle is unused here. It's accessible later through a global value anyway.
240
+ handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
241
+ for optimizer in optimizers:
242
+ # Disable Amp casting for the optimizer step, because it should only be
243
+ # applied to FP32 master params anyway.
244
+ def patch_step(old_step):
245
+ def new_step(*args, **kwargs):
246
+ with disable_casts():
247
+ output = old_step(*args, **kwargs)
248
+ return output
249
+ return new_step
250
+
251
+ optimizer.step = patch_step(optimizer.step)
252
+
253
+ if optimizers_was_list:
254
+ if models_was_list:
255
+ return models, optimizers
256
+ else:
257
+ return models[0], optimizers
258
+ else:
259
+ if models_was_list:
260
+ if len(optimizers) == 0:
261
+ return models
262
+ else:
263
+ return models, optimizers[0]
264
+ else:
265
+ if len(optimizers) == 0:
266
+ return models[0]
267
+ else:
268
+ return models[0], optimizers[0]
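The heart of `_initialize` is the `patch_forward` closure: floating-point inputs are cast to the model's dtype on the way in, and outputs are cast back to FP32 on the way out. The following is a standalone toy version of that pattern, with simplified stand-ins for `applier`/`to_type` that only handle Tensors and plain lists/tuples; it is an illustration, not the apex implementation:

```python
# Toy illustration of the forward-patching pattern: cast floating-point
# inputs to the model dtype, cast outputs back to float32.
import torch

def _cast_floats(obj, dtype):
    # Simplified applier(): only Tensors and (nested) lists/tuples are handled.
    if isinstance(obj, torch.Tensor):
        return obj.to(dtype) if obj.is_floating_point() else obj
    if isinstance(obj, (list, tuple)):
        return type(obj)(_cast_floats(o, dtype) for o in obj)
    return obj

def patch_forward(module, dtype=torch.float16):
    old_fwd = module.forward
    def new_fwd(*args, **kwargs):
        args = _cast_floats(args, dtype)
        kwargs = {k: _cast_floats(v, dtype) for k, v in kwargs.items()}
        return _cast_floats(old_fwd(*args, **kwargs), torch.float32)
    module.forward = new_fwd
    return module

# Usage sketch: patched = patch_forward(torch.nn.Linear(8, 8).half().cuda())
```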
apex/apex/amp/_process_optimizer.py ADDED
@@ -0,0 +1,411 @@
1
+ import types
2
+ from ..fp16_utils import master_params_to_model_params
3
+ from ..multi_tensor_apply import multi_tensor_applier
4
+ from ._amp_state import maybe_print
5
+ import torch
6
+
7
+
8
+ class AmpOptimizerState(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+
13
+ def lazy_init_with_master_weights(self):
14
+ stash = self._amp_stash
15
+ stash.fp16_groups = []
16
+ stash.fp32_from_fp16_groups = []
17
+ stash.fp32_from_fp32_groups = []
18
+ for i, param_group in enumerate(self.param_groups):
19
+ # maybe_print("FP16_Optimizer processing param group {}:".format(i))
20
+ fp16_params_this_group = []
21
+ fp32_params_this_group = []
22
+ fp32_from_fp16_params_this_group = []
23
+ for i, param in enumerate(param_group['params']):
24
+ if param.requires_grad:
25
+ if param.type() == 'torch.cuda.HalfTensor':
26
+ # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
27
+ # .format(param.size()))
28
+ fp16_params_this_group.append(param)
29
+ master_param = param.detach().clone().float()
30
+ master_param.requires_grad = True
31
+ param_group['params'][i] = master_param
32
+ fp32_from_fp16_params_this_group.append(master_param)
33
+ # Reset existing state dict key to the new master param.
34
+ # We still need to recast per-param state tensors, if any, to FP32.
35
+ if param in self.state:
36
+ self.state[master_param] = self.state.pop(param)
37
+ elif param.type() == 'torch.cuda.FloatTensor':
38
+ # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
39
+ # .format(param.size()))
40
+ fp32_params_this_group.append(param)
41
+ param_group['params'][i] = param
42
+ else:
43
+ raise TypeError("Optimizer's parameters must be either "
44
+ "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
45
+ "Received {}".format(param.type()))
46
+
47
+ stash.fp16_groups.append(fp16_params_this_group)
48
+ stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
49
+ stash.fp32_from_fp32_groups.append(fp32_params_this_group)
50
+
51
+ stash.all_fp16_params = []
52
+ for group in stash.fp16_groups:
53
+ stash.all_fp16_params += group
54
+
55
+ stash.all_fp32_from_fp16_params = []
56
+ for group in stash.fp32_from_fp16_groups:
57
+ stash.all_fp32_from_fp16_params += group
58
+
59
+ stash.all_fp32_from_fp32_params = []
60
+ for group in stash.fp32_from_fp32_groups:
61
+ stash.all_fp32_from_fp32_params += group
62
+
63
+ # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
64
+ stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params]
65
+
66
+ for param in stash.all_fp32_from_fp16_params:
67
+ param.grad = None
68
+
69
+ for param in stash.all_fp32_from_fp32_params:
70
+ param.grad = None
71
+
72
+ # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
73
+ self.load_state_dict(self.state_dict())
74
+
75
+
76
+ def prepare_backward_with_master_weights(self):
77
+ stash = self._amp_stash
78
+
79
+ if not stash.lazy_init_called:
80
+ self._lazy_init_maybe_master_weights()
81
+ stash.lazy_init_called = True
82
+
83
+ for i, param in enumerate(stash.all_fp16_params):
84
+ # Set up to leverage grad copy elision:
85
+ param.grad = None
86
+
87
+ # for i, param in enumerate(stash.all_fp32_from_fp16_params):
88
+ # stash.all_fp32_from_fp16_grad_stash[i] = param.grad
89
+
90
+ for i, param in enumerate(stash.all_fp32_from_fp32_params):
91
+ stash.all_fp32_from_fp32_grad_stash[i] = param.grad
92
+ # Set up to leverage grad copy elision:
93
+ param.grad = None
94
+
95
+
96
+ def post_backward_with_master_weights(self, scaler):
97
+ stash = self._amp_stash
98
+
99
+ # This is a lot of python overhead...
100
+ fp16_grads_needing_unscale = []
101
+ new_fp32_grads = []
102
+ fp16_grads_needing_unscale_with_stash = []
103
+ preexisting_fp32_grads = []
104
+ for fp16_param, fp32_param in zip(stash.all_fp16_params,
105
+ stash.all_fp32_from_fp16_params):
106
+ if fp16_param.grad is None and fp32_param.grad is not None:
107
+ continue
108
+ elif fp16_param.grad is not None and fp32_param.grad is None:
109
+ fp32_param.grad = torch.empty_like(fp32_param)
110
+ fp16_grads_needing_unscale.append(fp16_param.grad)
111
+ new_fp32_grads.append(fp32_param.grad)
112
+ elif fp16_param.grad is not None and fp32_param.grad is not None:
113
+ fp16_grads_needing_unscale_with_stash.append(fp16_param.grad)
114
+ preexisting_fp32_grads.append(fp32_param.grad)
115
+ else: # fp16_param.grad is None and fp32_param.grad is None:
116
+ continue
117
+
118
+ if len(fp16_grads_needing_unscale) > 0:
119
+ scaler.unscale(
120
+ fp16_grads_needing_unscale,
121
+ new_fp32_grads,
122
+ scaler.loss_scale(),
123
+ models_are_masters=False)
124
+
125
+ if len(fp16_grads_needing_unscale_with_stash) > 0:
126
+ scaler.unscale_with_stashed(
127
+ fp16_grads_needing_unscale_with_stash,
128
+ preexisting_fp32_grads,
129
+ preexisting_fp32_grads)
130
+
131
+ # fp32 params can be treated as they would be in the "no_master_weights" case.
132
+ grads_needing_unscale = []
133
+ grads_needing_unscale_with_stash = []
134
+ stashed = []
135
+ for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
136
+ stash.all_fp32_from_fp32_grad_stash):
137
+ if param.grad is None and stashed_grad is not None:
138
+ param.grad = stashed_grad
139
+ elif param.grad is not None and stashed_grad is None:
140
+ grads_needing_unscale.append(param.grad)
141
+ elif param.grad is not None and stashed_grad is not None:
142
+ grads_needing_unscale_with_stash.append(param.grad)
143
+ stashed.append(stashed_grad)
144
+ else: # param.grad is None and stashed_grad is None:
145
+ continue
146
+
147
+ if len(grads_needing_unscale) > 0:
148
+ scaler.unscale(
149
+ grads_needing_unscale,
150
+ grads_needing_unscale,
151
+ scaler.loss_scale(),
152
+ models_are_masters=True)
153
+
154
+ if len(grads_needing_unscale_with_stash) > 0:
155
+ scaler.unscale_with_stashed(
156
+ grads_needing_unscale_with_stash,
157
+ stashed,
158
+ grads_needing_unscale_with_stash)
159
+
160
+ # Clear the stash.
161
+ for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
162
+ stash.all_fp32_from_fp32_grad_stash[i] = None
163
+
164
+
165
+ def lazy_init_no_master_weights(self):
166
+ stash = self._amp_stash
167
+ stash.all_fp16_params = []
168
+ stash.all_fp32_params = []
169
+ for i, param_group in enumerate(self.param_groups):
170
+ for i, param in enumerate(param_group['params']):
171
+ if param.type() == 'torch.cuda.HalfTensor':
172
+ stash.all_fp16_params.append(param)
173
+ elif param.type() == 'torch.cuda.FloatTensor':
174
+ stash.all_fp32_params.append(param)
175
+ else:
176
+ raise TypeError("Optimizer's parameters must be either "
177
+ "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
178
+ "Received {}".format(param.type()))
179
+
180
+ stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
181
+ stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
182
+
183
+
184
+ def prepare_backward_no_master_weights(self):
185
+ stash = self._amp_stash
186
+
187
+ if not stash.lazy_init_called:
188
+ self._lazy_init_maybe_master_weights()
189
+ stash.lazy_init_called = True
190
+
191
+ for i, param in enumerate(stash.all_fp16_params):
192
+ stash.all_fp16_grad_stash[i] = param.grad
193
+ # Set up to leverage grad copy elision:
194
+ param.grad = None
195
+
196
+ for i, param in enumerate(stash.all_fp32_params):
197
+ stash.all_fp32_grad_stash[i] = param.grad
198
+ # Set up to leverage grad copy elision:
199
+ param.grad = None
200
+
201
+
202
+ def post_backward_no_master_weights(self, scaler):
203
+ stash = self._amp_stash
204
+
205
+ split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash),
206
+ (stash.all_fp32_params, stash.all_fp32_grad_stash))
207
+
208
+ for params, stashed_grads in split_types:
209
+ # This is a lot of python overhead...
210
+ grads_needing_unscale = []
211
+ grads_needing_unscale_with_stash = []
212
+ stashed = []
213
+ for param, stashed_grad in zip(params, stashed_grads):
214
+ if param.grad is None and stashed_grad is not None:
215
+ param.grad = stashed_grad
216
+ elif param.grad is not None and stashed_grad is None:
217
+ grads_needing_unscale.append(param.grad)
218
+ elif param.grad is not None and stashed_grad is not None:
219
+ grads_needing_unscale_with_stash.append(param.grad)
220
+ stashed.append(stashed_grad)
221
+ else: # param.grad is None and stashed_grad is None
222
+ continue
223
+
224
+ if len(grads_needing_unscale) > 0:
225
+ scaler.unscale(
226
+ grads_needing_unscale,
227
+ grads_needing_unscale,
228
+ scaler.loss_scale(),
229
+ models_are_masters=True)
230
+
231
+ if len(grads_needing_unscale_with_stash) > 0:
232
+ scaler.unscale_with_stashed(
233
+ grads_needing_unscale_with_stash,
234
+ stashed,
235
+ grads_needing_unscale_with_stash)
236
+
237
+ # Clear the stash.
238
+ for i in range(len(stashed_grads)):
239
+ stashed_grads[i] = None
240
+
241
+
242
+ def _master_params_to_model_params(self):
243
+ stash = self._amp_stash
244
+ if multi_tensor_applier.available:
245
+ if len(stash.all_fp16_params) > 0:
246
+ multi_tensor_applier(
247
+ stash.multi_tensor_scale,
248
+ stash.dummy_overflow_buf,
249
+ [stash.all_fp32_from_fp16_params, stash.all_fp16_params],
250
+ 1.0)
251
+ else:
252
+ for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups):
253
+ master_params_to_model_params(fp16_group, fp32_from_fp16_group)
254
+
255
+
256
+ def _process_optimizer(optimizer, properties):
257
+ if hasattr(optimizer, "_amp_stash"):
258
+ raise RuntimeError("A given optimizer should only be passed through amp.initialize once.")
259
+ else:
260
+ optimizer._amp_stash = AmpOptimizerState()
261
+
262
+ optimizer._amp_stash.lazy_init_called = False
263
+ optimizer._amp_stash.already_patched = False
264
+ optimizer._amp_stash.params_have_scaled_gradients = False
265
+
266
+ for name in ("_lazy_init_maybe_master_weights",
267
+ "_master_params_to_model_params",
268
+ "_prepare_amp_backward",
269
+ "_post_amp_backward"):
270
+ if hasattr(optimizer, name):
271
+ raise RuntimeError("Incoming optimizer already has {} defined.".format(name))
272
+
273
+ # TODO: Centralize exposure and import error checking for the C backend.
274
+ if multi_tensor_applier.available:
275
+ import amp_C
276
+ optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
277
+ optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0])
278
+
279
+ if properties.master_weights:
280
+ optimizer._lazy_init_maybe_master_weights = types.MethodType(
281
+ lazy_init_with_master_weights, optimizer)
282
+
283
+ optimizer._master_params_to_model_params = types.MethodType(
284
+ _master_params_to_model_params, optimizer)
285
+
286
+ old_step = optimizer.step
287
+ def new_step(self, closure=None):
288
+ if closure is not None:
289
+ raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
290
+ retval = old_step()
291
+ self._master_params_to_model_params()
292
+ # Clear the master grads that wouldn't be zeroed by model.zero_grad()
293
+ for param in self._amp_stash.all_fp32_from_fp16_params:
294
+ param.grad = None
295
+ return retval
296
+ optimizer.step = types.MethodType(new_step, optimizer)
297
+
298
+ old_zero_grad = optimizer.zero_grad
299
+ def new_zero_grad(self):
300
+ stash = self._amp_stash
301
+ if not stash.lazy_init_called:
302
+ self._lazy_init_maybe_master_weights()
303
+ stash.lazy_init_called = True
304
+ # Zero the model grads.
305
+ for param in stash.all_fp16_params:
306
+ if param.grad is not None:
307
+ param.grad.detach_()
308
+ param.grad.zero_()
309
+ for param in stash.all_fp32_from_fp32_params:
310
+ if param.grad is not None:
311
+ param.grad.detach_()
312
+ param.grad.zero_()
313
+ # Clear the master grads that are independent of model grads
314
+ for param in self._amp_stash.all_fp32_from_fp16_params:
315
+ param.grad = None
316
+ optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)
317
+
318
+ optimizer._prepare_amp_backward = types.MethodType(
319
+ prepare_backward_with_master_weights, optimizer)
320
+
321
+ optimizer._post_amp_backward = types.MethodType(
322
+ post_backward_with_master_weights, optimizer)
323
+ else:
324
+ optimizer._lazy_init_maybe_master_weights = types.MethodType(
325
+ lazy_init_no_master_weights, optimizer)
326
+
327
+ optimizer._prepare_amp_backward = types.MethodType(
328
+ prepare_backward_no_master_weights, optimizer)
329
+
330
+ optimizer._post_amp_backward = types.MethodType(
331
+ post_backward_no_master_weights, optimizer)
332
+
333
+ old_add_param_group = optimizer.add_param_group
334
+
335
+ def new_add_param_group(self, new_group):
336
+ stash = self._amp_stash
337
+
338
+ if not stash.lazy_init_called:
339
+ self._lazy_init_maybe_master_weights()
340
+ stash.lazy_init_called = True
341
+
342
+ assert isinstance(new_group, dict), "param group must be a dict"
343
+
344
+ new_params = new_group['params']
345
+ if isinstance(new_params, torch.Tensor):
346
+ new_group['params'] = [new_params]
347
+ elif isinstance(new_params, set):
348
+ raise TypeError('optimizer parameters need to be organized in ordered collections, but '
349
+ 'the ordering of tensors in sets will change between runs. Please use a list instead.')
350
+ else:
351
+ new_group['params'] = list(new_params)
352
+
353
+ if properties.master_weights:
354
+ # Mutate new_group in-place to use FP32 master params
355
+ fp16_params_this_group = []
356
+ fp32_params_this_group = []
357
+ fp32_from_fp16_params_this_group = []
358
+ for i, param in enumerate(new_group['params']):
359
+ if param.requires_grad:
360
+ if param.type() == 'torch.cuda.HalfTensor':
361
+ fp16_params_this_group.append(param)
362
+ master_param = param.detach().clone().float()
363
+ master_param.requires_grad = True
364
+ new_group['params'][i] = master_param
365
+ fp32_from_fp16_params_this_group.append(master_param)
366
+ elif param.type() == 'torch.cuda.FloatTensor':
367
+ fp32_params_this_group.append(param)
368
+ new_group['params'][i] = param
369
+ else:
370
+ raise TypeError("Optimizer's parameters must be either "
371
+ "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
372
+ "Received {}".format(param.type()))
373
+
374
+ stash.fp16_groups.append(fp16_params_this_group)
375
+ stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
376
+ stash.fp32_from_fp32_groups.append(fp32_params_this_group)
377
+
378
+ stash.all_fp16_params += fp16_params_this_group
379
+ stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group
380
+ stash.all_fp32_from_fp32_params += fp32_params_this_group
381
+
382
+ # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params]
383
+ stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group]
384
+
385
+ # It should be ok to let params be added with existing .grad attributes.
386
+ # for param in fp16_params_this_group:
387
+ # param.grad = None
388
+
389
+ # for param in fp32_from_fp16_params_this_group:
390
+ # param.grad = None
391
+
392
+ # for param in stash.fp32_params_this_group:
393
+ # param.grad = None
394
+ else:
395
+ for param in new_group['params']:
396
+ if param.type() == 'torch.cuda.HalfTensor':
397
+ stash.all_fp16_params.append(param)
398
+ stash.all_fp16_grad_stash.append(None)
399
+ elif param.type() == 'torch.cuda.FloatTensor':
400
+ stash.all_fp32_params.append(param)
401
+ stash.all_fp32_grad_stash.append(None)
402
+ else:
403
+ raise TypeError("Optimizer's parameters must be either "
404
+ "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
405
+ "Received {}".format(param.type()))
406
+
407
+ old_add_param_group(new_group)
408
+
409
+ optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer)
410
+
411
+ return optimizer
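Stripped of the stash bookkeeping above, the master-weights flow that `_process_optimizer` wires up reduces to: keep FP32 copies of the FP16 params, move (and unscale) grads onto the FP32 copies, step the optimizer on them, then copy back. A self-contained toy sketch of that round trip (assumes a CUDA device; not the apex internals):

```python
# Toy sketch of the FP16-params / FP32-master-weights round trip.
import torch

params_fp16 = [torch.nn.Parameter(torch.zeros(8, device="cuda", dtype=torch.float16))]
params_fp32 = [p.detach().clone().float().requires_grad_(True) for p in params_fp16]
optimizer = torch.optim.SGD(params_fp32, lr=0.1)

# ... after backward() has produced (scaled) grads on the FP16 params ...
for p16, p32 in zip(params_fp16, params_fp32):
    if p16.grad is not None:
        p32.grad = p16.grad.float()  # the real code also unscales here

optimizer.step()
with torch.no_grad():
    for p16, p32 in zip(params_fp16, params_fp32):
        p16.copy_(p32)  # master -> model copy, as in _master_params_to_model_params
```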
apex/apex/amp/amp.py ADDED
@@ -0,0 +1,177 @@
1
+ from . import compat, rnn_compat, utils, wrap
2
+ from .handle import AmpHandle, NoOpHandle
3
+ from .lists import functional_overrides, torch_overrides, tensor_overrides
4
+ from ._amp_state import _amp_state
5
+ from .frontend import *
6
+
7
+ import functools
8
+ import itertools
9
+
10
+ import torch
11
+
12
+
13
+ _DECORATOR_HANDLE = None
14
+ _USER_CAST_REGISTRY = set()
15
+ _USER_PROMOTE_REGISTRY = set()
16
+
17
+
18
+ def _decorator_helper(orig_fn, cast_fn, wrap_fn):
19
+ def wrapper(*args, **kwargs):
20
+ handle = _DECORATOR_HANDLE
21
+ if handle is None or not handle.is_active():
22
+ return orig_fn(*args, **kwargs)
23
+ inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__,
24
+ handle.verbose)
25
+ return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs)
26
+ return wrapper
27
+
28
+
29
+ # Decorator form
30
+ def half_function(fn):
31
+ wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True)
32
+ return _decorator_helper(fn, utils.maybe_half, wrap_fn)
33
+
34
+
35
+ def float_function(fn):
36
+ wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False)
37
+ return _decorator_helper(fn, utils.maybe_float, wrap_fn)
38
+
39
+
40
+ def promote_function(fn):
41
+ wrap_fn = functools.partial(wrap.make_promote_wrapper)
42
+ return _decorator_helper(fn, utils.maybe_float, wrap_fn)
43
+
44
+
45
+ # Registry form
46
+ def register_half_function(module, name):
47
+ if not hasattr(module, name):
48
+ raise ValueError('No function named {} in module {}.'.format(
49
+ name, module))
50
+ _USER_CAST_REGISTRY.add((module, name, utils.maybe_half))
51
+
52
+
53
+ def register_float_function(module, name):
54
+ if not hasattr(module, name):
55
+ raise ValueError('No function named {} in module {}.'.format(
56
+ name, module))
57
+ _USER_CAST_REGISTRY.add((module, name, utils.maybe_float))
58
+
59
+
60
+ def register_promote_function(module, name):
61
+ if not hasattr(module, name):
62
+ raise ValueError('No function named {} in module {}.'.format(
63
+ name, module))
64
+ _USER_PROMOTE_REGISTRY.add((module, name))
65
+
66
+
67
+ # Top-level function to insert _all_ the hooks.
68
+ def init(enabled=True, loss_scale="dynamic", enable_caching=True, verbose=False, allow_banned=False):
69
+ global _DECORATOR_HANDLE
70
+
71
+ if not enabled:
72
+ handle = NoOpHandle()
73
+ _DECORATOR_HANDLE = handle
74
+ return handle
75
+
76
+ handle = AmpHandle(loss_scale, enable_caching, verbose)
77
+
78
+ # 0) Force-{fp16, fp32} for user-annotated functions
79
+ for mod, fn, cast_fn in _USER_CAST_REGISTRY:
80
+ try_caching = (cast_fn == utils.maybe_half)
81
+ wrap.cached_cast(mod, fn, cast_fn, handle,
82
+ try_caching, verbose)
83
+ _USER_CAST_REGISTRY.clear()
84
+
85
+ # 0.5) Force-promote for user-annotated functions
86
+ for mod, fn in _USER_PROMOTE_REGISTRY:
87
+ wrap.promote(mod, fn, handle, verbose)
88
+ _USER_PROMOTE_REGISTRY.clear()
89
+
90
+ # 1) Force-{fp16, fp32} on white- / black-list functions
91
+ override_modules = [functional_overrides,
92
+ torch_overrides,
93
+ tensor_overrides]
94
+ cast_table = [('FP16_FUNCS', utils.maybe_half),
95
+ ('FP32_FUNCS', utils.maybe_float)]
96
+ for module, (list_name, cast_fn) in itertools.product(override_modules,
97
+ cast_table):
98
+ for fn in getattr(module, list_name):
99
+ try_caching = (cast_fn == utils.maybe_half)
100
+ wrap.cached_cast(module.MODULE, fn, cast_fn, handle,
101
+ try_caching, verbose)
102
+
103
+ # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist
104
+ # methods on FloatTensor, since they're distinct types.
105
+ if compat.tensor_is_float_tensor():
106
+ for fn in tensor_overrides.FP16_FUNCS:
107
+ wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half,
108
+ handle, try_caching=True, verbose=verbose)
109
+ for fn in tensor_overrides.FP32_FUNCS:
110
+ wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float,
111
+ handle, try_caching=False, verbose=verbose)
112
+
113
+ # 2) Enable type-promotion on multi-arg functions and methods.
114
+ # NB: special handling for sequence fns (e.g. `torch.cat`).
115
+ promote_modules = [torch_overrides, tensor_overrides]
116
+ promote_table = [('CASTS', wrap.promote),
117
+ ('SEQUENCE_CASTS', wrap.sequence_promote)]
118
+ for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
119
+ promote_table):
120
+ for fn in getattr(promote_mod, list_name):
121
+ promote_fn(promote_mod.MODULE, fn, handle, verbose)
122
+
123
+ # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
124
+ if compat.tensor_is_float_tensor():
125
+ for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor,
126
+ torch.cuda.HalfTensor],
127
+ promote_table):
128
+ for fn in getattr(tensor_overrides, list_name):
129
+ promote_fn(cls, fn, handle, verbose)
130
+
131
+ # 3) For any in-place version of a blacklist function, error if any input is fp16.
132
+ # NB: this is overly conservative.
133
+ for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
134
+ wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
135
+
136
+ # 3.5) For any in-place blacklist method, error if called on fp16 tensor
137
+ for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
138
+ wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
139
+ if compat.tensor_is_float_tensor():
140
+ wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
141
+
142
+ # 4) For other in-place methods, match the type of self tensor
143
+ for fn in utils.as_inplace(itertools.chain(
144
+ tensor_overrides.FP16_FUNCS,
145
+ tensor_overrides.CASTS)):
146
+ wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
147
+ if compat.tensor_is_float_tensor():
148
+ wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
149
+ wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
150
+
151
+ # 5) RNNs + RNN cells are whitelisted specially
152
+ if rnn_compat.has_old_rnns():
153
+ wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
154
+ if not rnn_compat.has_old_rnns():
155
+ # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
156
+ torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
157
+ # Wrap all the rnns
158
+ for x in rnn_compat.RNN_NAMES:
159
+ wrap.new_rnn_cast(x.upper(), handle, verbose)
160
+
161
+ # Wrap all the RNN cells
162
+ rnn_compat.whitelist_rnn_cells(handle, verbose)
163
+
164
+ # 6) Place error+print message on banned functions.
165
+ # Or, if allow_banned, then cast to FP32.
166
+ for fn, err_msg in functional_overrides.BANNED_FUNCS:
167
+ if allow_banned:
168
+ wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
169
+ handle, try_caching=True, verbose=verbose)
170
+ else:
171
+ wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
172
+
173
+ _DECORATOR_HANDLE = handle
174
+
175
+ _amp_state.handle = handle
176
+
177
+ return handle
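All of the `wrap.*` calls in `init` boil down to monkey-patching functions on the torch/tensor namespaces with casting wrappers. A minimal standalone illustration of that idea follows; it is not the real `wrap.cached_cast`, which also handles parameter caching and verbosity:

```python
# Minimal monkey-patching sketch: replace a module-level function with a
# version that first casts floating-point Tensor arguments to a target dtype.
import functools
import torch

def make_cast_wrapper(orig_fn, dtype):
    @functools.wraps(orig_fn)
    def wrapper(*args, **kwargs):
        cast = lambda a: a.to(dtype) if isinstance(a, torch.Tensor) and a.is_floating_point() else a
        return orig_fn(*[cast(a) for a in args], **{k: cast(v) for k, v in kwargs.items()})
    return wrapper

# Illustrative only: force torch.mm to run its inputs in fp16.
torch.mm = make_cast_wrapper(torch.mm, torch.float16)
```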
apex/apex/amp/compat.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+
3
+ # True for post-0.4, when Variables/Tensors merged.
4
+ def variable_is_tensor():
5
+ v = torch.autograd.Variable()
6
+ return isinstance(v, torch.Tensor)
7
+
8
+ def tensor_is_variable():
9
+ x = torch.Tensor()
10
+ return type(x) == torch.autograd.Variable
11
+
12
+ # False for post-0.4
13
+ def tensor_is_float_tensor():
14
+ x = torch.Tensor()
15
+ return type(x) == torch.FloatTensor
16
+
17
+ # Akin to `torch.is_tensor`, but returns True for Variable
18
+ # objects in pre-0.4.
19
+ def is_tensor_like(x):
20
+ return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable)
21
+
22
+ # Wraps `torch.is_floating_point` if present, otherwise checks
23
+ # the suffix of `x.type()`.
24
+ def is_floating_point(x):
25
+ if hasattr(torch, 'is_floating_point'):
26
+ return torch.is_floating_point(x)
27
+ try:
28
+ torch_type = x.type()
29
+ return torch_type.endswith('FloatTensor') or \
30
+ torch_type.endswith('HalfTensor') or \
31
+ torch_type.endswith('DoubleTensor')
32
+ except AttributeError:
33
+ return False
34
+
35
+ def scalar_python_val(x):
36
+ if hasattr(x, 'item'):
37
+ return x.item()
38
+ else:
39
+ if isinstance(x, torch.autograd.Variable):
40
+ return x.data[0]
41
+ else:
42
+ return x[0]
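For a sense of what these helpers return on a post-0.4 PyTorch (where Tensors and Variables are merged and `torch.is_floating_point` exists), a quick hedged check might look like:

```python
# Quick illustration (assumes a post-0.4 PyTorch install).
import torch
from apex.amp import compat

x = torch.zeros(3, dtype=torch.float16)
print(compat.is_tensor_like(x))                      # True
print(compat.is_floating_point(x))                   # True
print(compat.scalar_python_val(torch.tensor(2.5)))   # 2.5
```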
apex/apex/amp/frontend.py ADDED
@@ -0,0 +1,399 @@
1
+ import torch
2
+ from ._initialize import _initialize
3
+ from ._amp_state import _amp_state, warn_or_err, maybe_print
4
+
5
+
6
+ class Properties(object):
7
+ """
8
+ This class has two purposes: to establish a set of default properties,
9
+ and to route setting of these attributes through __setattr__ so that (in theory)
10
+ they can be checked for consistency with other existing args.
11
+ """
12
+ def __init__(self):
13
+ self.options = {
14
+ "enabled" : False,
15
+ "opt_level" : None,
16
+ "cast_model_type" : None,
17
+ "patch_torch_functions" : False,
18
+ "keep_batchnorm_fp32" : None,
19
+ "master_weights" : None,
20
+ "loss_scale" : 1.0,
21
+ # Reserved for future functionality
22
+ # "fused_optimizer" : False,
23
+ # "enable_ddp_interop" : False,
24
+ }
25
+
26
+ """
27
+ This function allows updating several options at a time without routing through
28
+ __setattr__ checks, to avoid "you can't get there from here" scenarios.
29
+ Currently not intended to be exposed; users are expected to select an opt_level
30
+ and apply consistent modifications.
31
+ """
32
+ def _update_options_dict(self, new_options):
33
+ for k, v in new_options:
34
+ if k in self.options:
35
+ self.options[k] = v
36
+ else:
37
+ raise ValueError("Tried to set unexpected option {}".format(k))
38
+ """
39
+ The members of "options" are not direct attributes of self, so access attempts
40
+ will roll down to __getattr__. This borrows from the logic in torch.nn.Module.
41
+ """
42
+ def __getattr__(self, name):
43
+ if "options" in self.__dict__:
44
+ options = self.__dict__["options"]
45
+ if name in options:
46
+ return options[name]
47
+ raise AttributeError("'{}' object has no attribute '{}'".format(
48
+ type(self).__name__, name))
49
+
50
+ def __setattr__(self, name, value):
51
+ if "options" in self.__dict__:
52
+ if name in self.options:
53
+ # print("setting {} {}".format(name, value))
54
+ if name == "cast_model_type":
55
+ if self.opt_level == "O1" and value is not None:
56
+ if value is not False:
57
+ if value is not torch.float32:
58
+ warn_or_err("O1 inserts casts around Torch functions rather than "
59
+ "model weights, so with O1, the model weights themselves "
60
+ "should remain FP32. If you wish to cast the model to a "
61
+ "different type, use opt_level='O2' or 'O3'. " +
62
+ "cast_model_type was {}".format(value))
63
+ self.options[name] = value
64
+ elif name == "patch_torch_functions":
65
+ if self.opt_level != "O1" and value:
66
+ warn_or_err("Currently, patch_torch_functions=True should only be set by "
67
+ "selecting opt_level='O1'.")
68
+ self.options[name] = value
69
+ elif name == "keep_batchnorm_fp32":
70
+ if self.opt_level == "O1" and value is not None:
71
+ warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
72
+ "to run in FP32, so keep_batchnorm_fp32 should be None." +
73
+ " keep_batchnorm_fp32 was {}".format(value))
74
+ if value == "False":
75
+ self.options[name] = False
76
+ elif value == "True":
77
+ self.options[name] = True
78
+ else:
79
+ assert (value is True or value is False or value is None),\
80
+ "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
81
+ "or None, found keep_batchnorm_fp32={}".format(value)
82
+ self.options[name] = value
83
+ elif name == "master_weights":
84
+ if self.opt_level == "O1" and value is not None:
85
+ warn_or_err("It doesn't make sense to use master_weights with O1. "
86
+ "With O1, your model weights themselves should be FP32.")
87
+ self.options[name] = value
88
+ elif name == "loss_scale":
89
+ if value == "dynamic":
90
+ self.options[name] = value
91
+ else:
92
+ self.options[name] = float(value)
93
+ else:
94
+ self.options[name] = value
95
+ else:
96
+ super(Properties, self).__setattr__(name, value)
97
+
98
+
99
+ """ O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """
100
+
101
+ class O3:
102
+ brief = "O3: Pure FP16 training."
103
+ more = "Calls .half() on your model, converting the entire model to FP16.\n"\
104
+ "A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
105
+ "so you don't need to change your data pipeline.\n"\
106
+ "This mode is useful for establishing a performance ceiling.\n"\
107
+ "It's also possible training may 'just work' in this mode.\n"\
108
+ "If not, try other optimization levels."
109
+
110
+ def __call__(self, properties):
111
+ properties.enabled = True
112
+ properties.opt_level = "O3"
113
+ properties.cast_model_type = torch.float16
114
+ properties.patch_torch_functions = False
115
+ properties.keep_batchnorm_fp32 = False
116
+ properties.master_weights = False
117
+ properties.loss_scale = 1.0
118
+ # properties.fused_optimizer = False
119
+ # properties.enable_ddp_interop = False
120
+ return properties # modified in place so this isn't really necessary
121
+
122
+
123
+ class O2:
124
+ brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n"
125
+ more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
126
+ "to FP16. Batchnorms are retained in FP32 for additional stability.\n"\
127
+ "The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
128
+ "your data pipeline.\n"\
129
+ "O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
130
+ "these master weights, then copy the master weights into the FP16 model weights.\n"\
131
+ "Master weights can also improve convergence and stability."
132
+
133
+ def __call__(self, properties):
134
+ properties.enabled = True
135
+ properties.opt_level = "O2"
136
+ properties.cast_model_type = torch.float16
137
+ properties.patch_torch_functions = False
138
+ properties.keep_batchnorm_fp32 = True
139
+ properties.master_weights = True
140
+ properties.loss_scale = "dynamic"
141
+ # properties.fused_optimizer = False
142
+ # properties.enable_ddp_interop = False
143
+ return properties # modified in place so this isn't really necessary
144
+
145
+
146
+ class O1:
147
+ brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n"
148
+ more = "The type of your model's weights is not altered. However, internally,\n"\
149
+ "Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
150
+ "while operations that might benefit from the additional stability of FP32 are patched\n"\
151
+ "to cast their inputs to fp32.\n"\
152
+ "O1 is the safest way to try mixed precision training, and is recommended when\n"\
153
+ "trying mixed precision training for the first time."
154
+
155
+ def __call__(self, properties):
156
+ properties.enabled = True
157
+ properties.opt_level = "O1"
158
+ properties.cast_model_type = None
159
+ properties.patch_torch_functions = True
160
+ properties.keep_batchnorm_fp32 = None
161
+ properties.master_weights = None
162
+ properties.loss_scale = "dynamic"
163
+ # properties.fused_optimizer = False
164
+ # properties.enable_ddp_interop = False
165
+ return properties # modified in place so this isn't really necessary
166
+
167
+
168
+ class O0:
169
+ brief = "O0: Pure FP32 training.\n"
170
+ more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
171
+ "types of weights and internal Pytorch operations are not altered. This mode disables any\n"\
172
+ "FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"
173
+
174
+ def __call__(self, properties):
175
+ properties.enabled = True
176
+ properties.opt_level = "O0"
177
+ properties.cast_model_type = torch.float32
178
+ properties.patch_torch_functions = False
179
+ properties.keep_batchnorm_fp32 = None
180
+ properties.master_weights = False
181
+ properties.loss_scale = 1.0
182
+ # properties.fused_optimizer = False
183
+ # properties.enable_ddp_interop = False
184
+ return properties # modified in place so this isn't really necessary
185
+
186
+
187
+ opt_levels = {"O3": O3(),
188
+ "O2": O2(),
189
+ "O1": O1(),
190
+ "O0": O0()}
191
+
192
+
193
+ # allow user to directly pass Properties struct as well?
194
+ def initialize(
195
+ models,
196
+ optimizers=None,
197
+ enabled=True,
198
+ opt_level="O1",
199
+ cast_model_type=None,
200
+ patch_torch_functions=None,
201
+ keep_batchnorm_fp32=None,
202
+ master_weights=None,
203
+ loss_scale=None,
204
+ cast_model_outputs=None,
205
+ num_losses=1,
206
+ verbosity=1,
207
+ min_loss_scale=None,
208
+ max_loss_scale=2.**24
209
+ ):
210
+ """
211
+ Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
212
+ chosen ``opt_level`` and overridden properties, if any.
213
+
214
+ ``amp.initialize`` should be called **after** you have finished
215
+ constructing your model(s) and
216
+ optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
217
+ See `Distributed training`_ in the Imagenet example.
218
+
219
+ Currently, ``amp.initialize`` should only be called **once**,
220
+ although it can process an arbitrary number of
221
+ models and optimizers (see the corresponding `Advanced Amp Usage topic`_).
222
+ If you think your use case requires ``amp.initialize`` to be called more than once,
223
+ `let us know`_.
224
+
225
+ Any property keyword argument that is not ``None`` will be interpreted as a manual override.
226
+
227
+ To prevent having to rewrite anything else in your script, name the returned models/optimizers
228
+ to replace the passed models/optimizers, as in the code sample below.
229
+
230
+ Args:
231
+ models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast.
232
+ optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast.
233
+ REQUIRED for training, optional for inference.
234
+ enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
235
+ should run as if Amp were not present.
236
+ opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
237
+ "O0", "O1", "O2", and "O3", explained in detail above.
238
+ cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
239
+ above.
240
+ patch_torch_functions (bool, optional, default=None): Optional property override.
241
+ keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If
242
+ passed as a string, must be the string "True" or "False".
243
+ master_weights (bool, optional, default=None): Optional property override.
244
+ loss_scale (float or str, optional, default=None): Optional property override. If passed as a string,
245
+ must be a string representing a number, e.g., "128.0", or the string "dynamic".
246
+ cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs
247
+ of your model(s) are always cast to a particular type regardless of ``opt_level``.
248
+ num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward
249
+ passes you plan to use. When used in conjunction with the ``loss_id`` argument to
250
+ ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass,
251
+ which can improve stability. See "Multiple models/optimizers/losses"
252
+ under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still
253
+ support multiple losses/backward passes, but use a single global loss scale
254
+ for all of them.
255
+ verbosity (int, default=1): Set to 0 to suppress Amp-related output.
256
+ min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
257
+ loss scaling. The default value of None means that no floor is imposed.
258
+ If dynamic loss scaling is not used, `min_loss_scale` is ignored.
259
+ max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
260
+ dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
261
+
262
+ Returns:
263
+ Model(s) and optimizer(s) modified according to the ``opt_level``.
264
+ If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
265
+ also be a list.
266
+
267
+ Permissible invocations::
268
+
269
+ model, optim = amp.initialize(model, optim,...)
270
+ model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
271
+ [model1, model2], optim = amp.initialize([model1, model2], optim,...)
272
+ [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)
273
+
274
+ # This is not an exhaustive list of the cross product of options that are possible,
275
+ # just a set of examples.
276
+ model, optim = amp.initialize(model, optim, opt_level="O0")
277
+ model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")
278
+
279
+ model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default
280
+ model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")
281
+
282
+ model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default
283
+ model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
284
+ model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")
285
+
286
+ model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
287
+ model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
288
+ model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")
289
+
290
+ The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.
291
+
292
+ .. _`Distributed training`:
293
+ https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training
294
+
295
+ .. _`Imagenet example`:
296
+ https://github.com/NVIDIA/apex/tree/master/examples/imagenet
297
+
298
+ .. _`Advanced Amp Usage`:
299
+ https://nvidia.github.io/apex/advanced.html
300
+
301
+ .. _`Advanced Amp Usage topic`:
302
+ https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses
303
+
304
+ .. _`let us know`:
305
+ https://github.com/NVIDIA/apex/issues
306
+ """
307
+ _amp_state.opt_properties = Properties()
308
+ _amp_state.verbosity = verbosity
309
+
310
+ if not enabled:
311
+ if optimizers is None:
312
+ return models
313
+ else:
314
+ return models, optimizers
315
+
316
+ if not torch.backends.cudnn.enabled:
317
+ raise RuntimeError(
318
+ "Amp requires torch.backends.cudnn.enabled = True")
319
+
320
+ if opt_level not in opt_levels:
321
+ raise RuntimeError(
322
+ "Unexpected optimization level {}. ".format(opt_level) +
323
+ "Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
324
+ "not the number zero.")
325
+ else:
326
+ _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
327
+ maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
328
+ maybe_print("Defaults for this optimization level are:", True)
329
+ for k, v in _amp_state.opt_properties.options.items():
330
+ maybe_print("{:22} : {}".format(k, v), True)
331
+
332
+ _amp_state.min_loss_scale = min_loss_scale
333
+ _amp_state.max_loss_scale = max_loss_scale
334
+
335
+ maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
336
+ # I chose to have the keyword arguments listed directly in the argument list,
337
+ # instead of **kwargs, so I can't use kwargs.items() here.
338
+ if enabled is not None:
339
+ _amp_state.opt_properties.enabled = enabled
340
+ if opt_level is not None:
341
+ _amp_state.opt_properties.opt_level = opt_level
342
+ if cast_model_type is not None:
343
+ _amp_state.opt_properties.cast_model_type = cast_model_type
344
+ if patch_torch_functions is not None:
345
+ _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
346
+ if keep_batchnorm_fp32 is not None:
347
+ _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
348
+ if master_weights is not None:
349
+ _amp_state.opt_properties.master_weights = master_weights
350
+ if loss_scale is not None:
351
+ _amp_state.opt_properties.loss_scale = loss_scale
352
+
353
+ maybe_print("After processing overrides, optimization options are:", True)
354
+ for k, v in _amp_state.opt_properties.options.items():
355
+ maybe_print("{:22} : {}".format(k, v), True)
356
+
357
+ return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs)
358
+
359
+
360
+ # TODO: is this necessary/useful?
361
+ # def check_option_consistency(enabled=True,
362
+ # opt_level=None,
363
+ # cast_model_type=None,
364
+ # patch_torch_functions=None,
365
+ # keep_batchnorm_fp32=None,
366
+ # master_weights=None,
367
+ # loss_scale=None,
368
+ # enable_ddp_interop=None,
369
+ # hard_override=False):
370
+ # """
371
+ # Utility function that enables users to quickly check if the option combination they intend
372
+ # to use is permitted. ``check_option_consistency`` does not require models or optimizers
373
+ # to be constructed, and can be called at any point in the script. ``check_option_consistency``
374
+ # is totally self-contained; it does not set any amp global state or affect anything outside
375
+ # of itself.
376
+ # """
377
+ #
378
+ # if not enabled:
379
+ # return
380
+ #
381
+ # if opt_level not in opt_levels:
382
+ # raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.")
383
+ # else:
384
+ # opt_properties = opt_levels[opt_level](Properties())
385
+ # print("Selected optimization level {}", opt_levels[opt_level].brief)
386
+ # print("Defaults for this optimization level are:")
387
+ # for k, v in opt_properties.options:
388
+ # print("{:22} : {}".format(k, v))
389
+ #
390
+ # print("Processing user overrides (additional kwargs that are not None)...")
391
+ # for k, v in kwargs:
392
+ # if k not in _amp_state.opt_properties.options:
393
+ # raise RuntimeError("Unexpected kwarg {}".format(k))
394
+ # if v is not None:
395
+ # setattr(opt_properties, k, v)
396
+ #
397
+ # print("After processing overrides, optimization options are:")
398
+ # for k, v in opt_properties.options:
399
+ # print("{:22} : {}".format(k, v))
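A minimal sketch (not part of the diff) of how ``amp.initialize`` above and ``amp.scale_loss`` from ``handle.py`` below are meant to be combined; the linear model, synthetic data, and hyperparameters are illustrative placeholders:

    import torch
    from apex import amp

    # Any torch.nn.Module / torch.optim optimizer pair works; this is a toy.
    model = torch.nn.Linear(512, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Casts/patches according to the chosen opt_level; overrides such as
    # loss_scale can be passed as extra kwargs, as documented above.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    data = torch.randn(32, 512, device="cuda")
    target = torch.randint(0, 10, (32,), device="cuda")

    for _ in range(10):
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(data), target)
        # scale_loss multiplies by the current loss scale and unscales the
        # gradients on context exit, so optimizer.step() sees true gradients.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()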
apex/apex/amp/handle.py ADDED
@@ -0,0 +1,280 @@
1
+ import contextlib
2
+ import warnings
3
+ import torch
4
+
5
+ from . import utils
6
+ from .opt import OptimWrapper
7
+ from .scaler import LossScaler
8
+ from ._amp_state import _amp_state, master_params, maybe_print
9
+ from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
10
+ from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
11
+ from ..parallel.LARC import LARC
12
+
13
+
14
+ # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
15
+ @contextlib.contextmanager
16
+ def scale_loss(loss,
17
+ optimizers,
18
+ loss_id=0,
19
+ model=None,
20
+ delay_unscale=False,
21
+ delay_overflow_check=False):
22
+ """
23
+ On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
24
+ ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
25
+
26
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
27
+ scaled_loss.backward()
28
+
29
+ On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs
30
+ and unscaled, so that ``optimizer.step()`` can be called.
31
+
32
+ .. note::
33
+ If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and
34
+ can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``)
35
+ any FP16 gradients are copied to FP32 master gradients before being unscaled.
36
+ ``optimizer.step()`` will then apply the unscaled master gradients to the master params.
37
+
38
+ .. warning::
39
+ If Amp is using explicit FP32 master params, only the FP32 master gradients will be
40
+ unscaled. The direct ``.grad`` attributes of any FP16
41
+ model params will remain scaled after context manager exit.
42
+ This subtlety affects gradient clipping. See "Gradient clipping" under
43
+ `Advanced Amp Usage`_ for best practices.
44
+
45
+ Args:
46
+ loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context
47
+ manager yields is simply ``loss.float()*loss_scale``, so in principle
48
+ ``loss`` could have more than one element, as long as you call
49
+ ``backward()`` on ``scaled_loss`` appropriately within the context manager body.
50
+ optimizers: All optimizer(s) for which the current backward pass is creating gradients.
51
+ Must be an optimizer or list of optimizers returned from an earlier call
52
+ to ``amp.initialize``. For example use with multiple optimizers, see
53
+ "Multiple models/optimizers/losses" under `Advanced Amp Usage`_.
54
+ loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument
55
+ to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id``
56
+ must be an integer between 0 and ``num_losses`` that tells Amp which loss is
57
+ being used for the current backward pass. See "Multiple models/optimizers/losses"
58
+ under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp
59
+ will use the default global loss scaler for this backward pass.
60
+ model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future
61
+ optimizations.
62
+ delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and
63
+ the default value of ``False`` is strongly recommended.
64
+ If ``True``, Amp will not unscale the gradients or perform model->master
65
+ gradient copies on context manager exit.
66
+ ``delay_unscale=True`` is a minor ninja performance optimization and can result
67
+ in weird gotchas (especially with multiple models/optimizers/losses),
68
+ so only use it if you know what you're doing.
69
+ "Gradient accumulation across iterations" under `Advanced Amp Usage`_
70
+ illustrates a situation where this CAN (but does not need to) be used.
71
+
72
+ .. warning::
73
+ If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be
74
+ called yet after context manager exit, and must wait for another, later backward context
75
+ manager invocation with ``delay_unscale`` left at the default ``False``.
76
+
77
+ .. _`Advanced Amp Usage`:
78
+ https://nvidia.github.io/apex/advanced.html
79
+ """
80
+ if not hasattr(_amp_state, "opt_properties"):
81
+ raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. "
82
+ "model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called "
83
+ "before `with amp.scale_loss`.")
84
+
85
+ if not _amp_state.opt_properties.enabled:
86
+ yield loss
87
+ return
88
+
89
+ if isinstance(optimizers, torch.optim.Optimizer) or isinstance(optimizers, LARC):
90
+ optimizers = [optimizers]
91
+
92
+ # This is what happens when I have to support tools from different sources under the same API...
93
+ # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
94
+ if isinstance(optimizers, FP16_Optimizer_for_fused):
95
+ loss_scale = optimizers.cur_scale
96
+ else:
97
+ loss_scaler = _amp_state.loss_scalers[loss_id]
98
+ loss_scale = loss_scaler.loss_scale()
99
+
100
+ if ((not _amp_state.opt_properties.master_weights)
101
+ and (not loss_scaler.dynamic)
102
+ and loss_scale == 1.0):
103
+ yield loss.float()
104
+ # Needing to drop the cache here as well is an ugly gotcha.
105
+ # But for now I think it's necessary to short-circuit.
106
+ # Probably ok to skip this if not delay_unscale
107
+ if _amp_state.opt_properties.patch_torch_functions:
108
+ _amp_state.handle._clear_cache()
109
+ return
110
+
111
+ if not delay_unscale:
112
+ if isinstance(optimizers, list):
113
+ for optimizer in optimizers:
114
+ if not optimizer._amp_stash.params_have_scaled_gradients:
115
+ optimizer._prepare_amp_backward()
116
+
117
+ yield (loss.float())*loss_scale
118
+
119
+ if delay_unscale:
120
+ for optimizer in optimizers:
121
+ optimizer._amp_stash.params_have_scaled_gradients = True
122
+ else:
123
+ # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
124
+ if not isinstance(optimizers, FP16_Optimizer_for_fused):
125
+ loss_scaler.clear_overflow_state()
126
+ for optimizer in optimizers:
127
+ optimizer._post_amp_backward(loss_scaler)
128
+ optimizer._amp_stash.params_have_scaled_gradients = False
129
+ # For future fused optimizers that enable sync-free dynamic loss scaling,
130
+ # should_skip will always be False.
131
+ should_skip = False if delay_overflow_check else loss_scaler.update_scale()
132
+ if should_skip:
133
+ for optimizer in optimizers:
134
+ if not optimizer._amp_stash.already_patched:
135
+ # Close on loss_scaler and loss_id as well, to be safe. Probably not
136
+ # necessary because amp.scale_loss is already creating a temporary scope.
137
+ def patch_step(opt, loss_scaler, loss_id):
138
+ opt_step = opt.step
139
+ def skip_step(closure=None):
140
+ if closure is not None:
141
+ raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
142
+ maybe_print(("Gradient overflow. Skipping step, loss scaler " +
143
+ "{} reducing loss scale to {}").format(loss_id,
144
+ loss_scaler.loss_scale()))
145
+ if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"):
146
+ # Clear the master grads that wouldn't be zeroed by model.zero_grad()
147
+ for param in opt._amp_stash.all_fp32_from_fp16_params:
148
+ param.grad = None
149
+ opt.step = opt_step
150
+ opt._amp_stash.already_patched = False
151
+ return skip_step
152
+ optimizer.step = patch_step(optimizer, loss_scaler, loss_id)
153
+ optimizer._amp_stash.already_patched = True
154
+
155
+ # Probably ok to skip this if not delay_unscale
156
+ if _amp_state.opt_properties.patch_torch_functions:
157
+ _amp_state.handle._clear_cache()
158
+
159
+
160
+ # Free function version of AmpHandle.disable_casts, another step on the
161
+ # path to removing the concept of "AmpHandle"
162
+ @contextlib.contextmanager
163
+ def disable_casts():
164
+ _amp_state.handle._is_active = False
165
+ yield
166
+ _amp_state.handle._is_active = True
167
+
168
+
169
+ class AmpHandle(object):
170
+ def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False):
171
+ self._enable_caching = enable_caching
172
+ self._verbose = verbose
173
+ self._cache = dict()
174
+ self._default_scaler = LossScaler(loss_scale)
175
+ self._is_active = True
176
+ self._all_wrappers = []
177
+
178
+ def is_active(self):
179
+ return self._is_active
180
+
181
+ @contextlib.contextmanager
182
+ def _disable_casts(self):
183
+ self._is_active = False
184
+ yield
185
+ self._is_active = True
186
+
187
+ def wrap_optimizer(self, optimizer, num_loss=1):
188
+ self._default_scaler = None
189
+ return OptimWrapper(optimizer, self, num_loss)
190
+
191
+ @contextlib.contextmanager
192
+ def scale_loss(self, loss, optimizer):
193
+ raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, "
194
+ "documented here: https://nvidia.github.io/apex/amp.html. Transition guide: "
195
+ "https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users")
196
+
197
+ if not self.is_active():
198
+ yield loss
199
+ return
200
+
201
+ if self._default_scaler is None:
202
+ raise RuntimeError(
203
+ 'After calling `handle.wrap_optimizer()`, you must explicitly ' +
204
+ 'use `optimizer.scale_loss(loss)`.')
205
+
206
+ # TODO: this code block is duplicated here and `opt.py`. Unify.
207
+ loss_scale = self._default_scaler.loss_scale()
208
+ yield loss * loss_scale
209
+
210
+ self._default_scaler.clear_overflow_state()
211
+ self._default_scaler.unscale(
212
+ master_params(optimizer),
213
+ master_params(optimizer),
214
+ loss_scale)
215
+ should_skip = self._default_scaler.update_scale()
216
+ if should_skip:
217
+ optimizer_step = optimizer.step
218
+ def skip_step():
219
+ maybe_print('Gradient overflow, skipping update')
220
+ optimizer.step = optimizer_step
221
+ optimizer.step = skip_step
222
+
223
+ self._clear_cache()
224
+
225
+ def _clear_cache(self):
226
+ self._cache.clear()
227
+
228
+ # Experimental support for saving / restoring uncasted versions of functions
229
+ def _save_func(self, mod, fn, func):
230
+ self._all_wrappers.append((mod, fn, func))
231
+
232
+ def _deactivate(self):
233
+ for mod, fn, func in self._all_wrappers:
234
+ utils.set_func(mod, fn, func)
235
+ self._all_wrappers = []
236
+
237
+ @property
238
+ def has_cache(self):
239
+ return self._enable_caching
240
+
241
+ @property
242
+ def cache(self):
243
+ return self._cache
244
+
245
+ def remove_cache(self, param):
246
+ if self.has_cache and param in self.cache:
247
+ del self.cache[param]
248
+
249
+ @property
250
+ def verbose(self):
251
+ return self._verbose
252
+
253
+ class NoOpHandle(object):
254
+ def is_active(self):
255
+ return False
256
+
257
+ @contextlib.contextmanager
258
+ def _disable_casts(self):
259
+ yield
260
+
261
+ def wrap_optimizer(self, optimizer, num_loss=1):
262
+ return OptimWrapper(optimizer, self, num_loss)
263
+
264
+ @contextlib.contextmanager
265
+ def scale_loss(self, loss, optimizer):
266
+ yield loss
267
+
268
+ @property
269
+ def has_cache(self):
270
+ return False
271
+
272
+ @property
273
+ def verbose(self):
274
+ return False
275
+
276
+ def _clear_cache(self):
277
+ pass
278
+
279
+ def _deactivate(self):
280
+ pass
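A rough sketch of the per-loss scaling that the ``loss_id``/``num_losses`` arguments above enable (the `Advanced Amp Usage`_ page linked in the docstring is the authoritative reference); ``criterion1``, ``criterion2``, ``data`` and the targets are placeholders assumed to exist:

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1", num_losses=2)

    out = model(data)
    loss1 = criterion1(out, target1)
    with amp.scale_loss(loss1, optimizer, loss_id=0) as scaled_loss:
        scaled_loss.backward(retain_graph=True)   # graph is reused by loss2 below

    loss2 = criterion2(out, target2)
    with amp.scale_loss(loss2, optimizer, loss_id=1) as scaled_loss:
        scaled_loss.backward()

    optimizer.step()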
apex/apex/amp/lists/__init__.py ADDED
File without changes
apex/apex/amp/lists/functional_overrides.py ADDED
@@ -0,0 +1,77 @@
1
+
2
+ # TODO: think about the following two. They do weird things.
3
+ # - torch.nn.utils.clip_grad (but it should always be fp32 anyway)
4
+ # - torch.nn.utils.weight_norm
5
+
6
+ # Notes:
7
+ # F.instance_norm uses batch_norm internally. Which correctly handles
8
+ # fp16 in/out with fp32 weights. So we shouldn't do anything for
9
+ # either of these.
10
+ # F.normalize calls `input.norm()` internally, so it's redundant, but
11
+ # kept here in case impl. changes.
12
+ # F.cosine_similarity is same: calls `x.norm()` internally.
13
+
14
+ import torch.nn.functional
15
+
16
+ MODULE = torch.nn.functional
17
+
18
+ FP16_FUNCS = [
19
+ 'conv1d',
20
+ 'conv2d',
21
+ 'conv3d',
22
+ 'conv_transpose1d',
23
+ 'conv_transpose2d',
24
+ 'conv_transpose3d',
25
+ 'conv_tbc', # Undocumented / maybe new?
26
+ 'linear',
27
+ ]
28
+
29
+ FP32_FUNCS = [
30
+
31
+ # Interpolation/Upsampling
32
+ 'interpolate',
33
+
34
+ # Pointwise
35
+ 'softplus',
36
+ 'softmin',
37
+ 'log_softmax',
38
+ 'softmax',
39
+
40
+ # Normalization
41
+ 'layer_norm',
42
+ 'group_norm',
43
+ 'local_response_norm',
44
+ 'normalize',
45
+ 'cosine_similarity',
46
+
47
+ # Loss functions
48
+ # TODO: which of these can be fp16?
49
+ 'poisson_nll_loss',
50
+ 'cosine_embedding_loss',
51
+ 'cross_entropy',
52
+ 'hinge_embedding_loss',
53
+ 'kl_div',
54
+ 'l1_loss',
55
+ 'mse_loss',
56
+ 'margin_ranking_loss',
57
+ 'multilabel_margin_loss',
58
+ 'multilabel_soft_margin_loss',
59
+ 'multi_margin_loss',
60
+ 'nll_loss',
61
+ 'binary_cross_entropy_with_logits',
62
+ 'smooth_l1_loss',
63
+ 'soft_margin_loss',
64
+ 'triplet_margin_loss'
65
+ ]
66
+
67
+ BANNED_FUNCS = [
68
+ ('binary_cross_entropy',
69
+ ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` "
70
+ "It requires that the output of the previous function be already a FloatTensor. \n\n"
71
+ "Most models have a Sigmoid right before BCELoss. In that case, you can use\n"
72
+ " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer "
73
+ "that is compatible with amp.\nAnother option is to add\n"
74
+ " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n"
75
+ "If you _really_ know what you are doing, you can disable this warning by passing "
76
+ "allow_banned=True to `amp.init()`."))
77
+ ]
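To make the ``binary_cross_entropy`` message above concrete, a hedged sketch of the two remedies it names (the tensors stand in for real model outputs and labels):

    import torch
    from apex import amp

    # Preferred fix: fuse Sigmoid + BCELoss into BCEWithLogitsLoss. Its
    # functional form, binary_cross_entropy_with_logits, is on the FP32_FUNCS
    # list above, so amp runs it in fp32 automatically.
    logits  = torch.randn(8, 1, device="cuda")             # raw, un-sigmoided outputs
    targets = torch.empty(8, 1, device="cuda").uniform_()
    loss = torch.nn.BCEWithLogitsLoss()(logits, targets)

    # Alternative named in the message: force torch.sigmoid to produce fp32
    # output. With the current front end, register before amp.initialize().
    amp.register_float_function(torch, 'sigmoid')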
apex/apex/amp/lists/tensor_overrides.py ADDED
@@ -0,0 +1,63 @@
1
+ from .. import compat
2
+ from . import torch_overrides
3
+
4
+ import importlib
5
+
6
+ import torch
7
+
8
+ if compat.variable_is_tensor() and not compat.tensor_is_variable():
9
+ MODULE = torch.Tensor
10
+ else:
11
+ MODULE = torch.autograd.Variable
12
+
13
+
14
+ FP16_FUNCS = [
15
+ '__matmul__',
16
+ ]
17
+
18
+ FP32_FUNCS = [
19
+ '__ipow__',
20
+ '__pow__',
21
+ '__rpow__',
22
+
23
+ # Cast to fp32 before transfer to CPU
24
+ 'cpu',
25
+ ]
26
+
27
+ CASTS = [
28
+ '__add__',
29
+ '__div__',
30
+ '__eq__',
31
+ '__ge__',
32
+ '__gt__',
33
+ '__iadd__',
34
+ '__idiv__',
35
+ '__imul__',
36
+ '__isub__',
37
+ '__itruediv__',
38
+ '__le__',
39
+ '__lt__',
40
+ '__mul__',
41
+ '__ne__',
42
+ '__radd__',
43
+ '__rdiv__',
44
+ '__rmul__',
45
+ '__rsub__',
46
+ '__rtruediv__',
47
+ '__sub__',
48
+ '__truediv__',
49
+ ]
50
+
51
+ # None of these, but here to make code cleaner.
52
+ SEQUENCE_CASTS = []
53
+
54
+ # We need to grab all the methods from torch_overrides and add them to
55
+ # the Tensor lists as well, as almost all methods are duplicated
56
+ # between `torch` and `torch.Tensor` (and check with `hasattr`,
57
+ # because a few random ones aren't defined on Tensor)
58
+ _self_mod = importlib.import_module(__name__)
59
+ for attrname in ['FP16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']:
60
+ lst = getattr(_self_mod, attrname)
61
+ for fn in getattr(torch_overrides, attrname):
62
+ if hasattr(MODULE, fn):
63
+ lst.append(fn)
apex/apex/amp/lists/torch_overrides.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+
3
+ from .. import utils
4
+
5
+ MODULE = torch
6
+
7
+ FP16_FUNCS = [
8
+ # Low level functions wrapped by torch.nn layers.
9
+ # The wrapper layers contain the weights which are then passed in as a parameter
10
+ # to these functions.
11
+ 'conv1d',
12
+ 'conv2d',
13
+ 'conv3d',
14
+ 'conv_transpose1d',
15
+ 'conv_transpose2d',
16
+ 'conv_transpose3d',
17
+ 'conv_tbc',
18
+ 'prelu',
19
+
20
+ # BLAS
21
+ 'addmm',
22
+ 'addmv',
23
+ 'addr',
24
+ 'matmul',
25
+ 'mm',
26
+ 'mv',
27
+ ]
28
+
29
+ FP32_FUNCS = [
30
+ # Pointwise
31
+ 'acos',
32
+ 'asin',
33
+ 'cosh',
34
+ 'erfinv',
35
+ 'exp',
36
+ 'expm1',
37
+ 'log',
38
+ 'log10',
39
+ 'log2',
40
+ 'reciprocal',
41
+ 'rsqrt',
42
+ 'sinh',
43
+ 'tan',
44
+
45
+ # Other math
46
+ 'pow',
47
+
48
+ # Reduction
49
+ 'cumprod',
50
+ 'cumsum',
51
+ 'dist',
52
+ 'mean',
53
+ 'norm',
54
+ 'prod',
55
+ 'std',
56
+ 'sum',
57
+ 'var',
58
+
59
+ # Misc
60
+ 'renorm'
61
+ ]
62
+
63
+ # Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
64
+ # check the CUDA version -- if at least 9.1, then put the bmm
65
+ # functions on the fp16 list. Otherwise, put them on the fp32 list.
66
+ _bmms = ['addbmm',
67
+ 'baddbmm',
68
+ 'bmm']
69
+ if utils.get_cuda_version() >= (9, 1, 0):
70
+ FP16_FUNCS.extend(_bmms)
71
+ else:
72
+ FP32_FUNCS.extend(_bmms)
73
+
74
+ # Multi-tensor fns that may need type promotion
75
+ CASTS = [
76
+ # Multi-tensor math
77
+ 'addcdiv',
78
+ 'addcmul',
79
+ 'atan2',
80
+ 'cross',
81
+ 'bilinear',
82
+
83
+ # Element-wise _or_ tensor-wise math
84
+ 'add',
85
+ 'div',
86
+ 'mul',
87
+
88
+ # Comparison
89
+ 'eq',
90
+ 'equal',
91
+ 'ge',
92
+ 'gt',
93
+ 'le',
94
+ 'lt',
95
+ 'ne'
96
+ ]
97
+
98
+ # Functions that take sequence arguments. We need to inspect the whole
99
+ # sequence and cast to the widest type.
100
+ SEQUENCE_CASTS = [
101
+ 'cat',
102
+ 'stack'
103
+ ]
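A small illustration of what the CASTS ("type promotion") list means in practice, assuming ``amp.initialize(..., opt_level="O1")`` has already patched these torch functions:

    import torch

    a = torch.randn(4, device="cuda").half()   # fp16 operand
    b = torch.randn(4, device="cuda")          # fp32 operand

    # Wrapped functions on the CASTS list promote mixed fp16/fp32 inputs to
    # fp32 instead of raising a type mismatch.
    out = torch.mul(a, b)
    assert out.dtype == torch.float32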
apex/apex/amp/opt.py ADDED
@@ -0,0 +1,103 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ from .scaler import LossScaler, master_params
5
+ from ._amp_state import maybe_print
6
+
7
+ import numpy as np
8
+
9
+ class OptimWrapper(object):
10
+ def __init__(self, optimizer, amp_handle, num_loss):
11
+ self._optimizer = optimizer
12
+ self._amp_handle = amp_handle
13
+ self._num_loss = num_loss
14
+ self._loss_idx = 0
15
+ self._skip_next = [False] * num_loss
16
+ self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)]
17
+
18
+ @contextlib.contextmanager
19
+ def scale_loss(self, loss):
20
+ if not self._amp_handle.is_active():
21
+ yield loss
22
+ return
23
+
24
+ # When there are multiple losses per-optimizer, we need
25
+ # to save out current grad accumulation, since we won't be
26
+ # able to unscale this particulare loss once the grads are
27
+ # all mixed together.
28
+ cached_grads = []
29
+ if self._loss_idx > 0:
30
+ for p in master_params(self._optimizer):
31
+ if p.grad is not None:
32
+ cached_grads.append(p.grad.data.detach().clone())
33
+ else:
34
+ cached_grads.append(None)
35
+ self._optimizer.zero_grad()
36
+
37
+ loss_scale = self._cur_loss_scaler().loss_scale()
38
+ yield loss * loss_scale
39
+
40
+ self._cur_loss_scaler().clear_overflow_state()
41
+ self._cur_loss_scaler().unscale(
42
+ master_params(self._optimizer),
43
+ master_params(self._optimizer),
44
+ loss_scale)
45
+ self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale()
46
+ self._loss_idx += 1
47
+
48
+ if len(cached_grads) > 0:
49
+ for p, cached_grad in zip(master_params(self._optimizer),
50
+ cached_grads):
51
+ if cached_grad is not None:
52
+ p.grad.data.add_(cached_grad)
53
+ cached_grads = []
54
+
55
+ def _cur_loss_scaler(self):
56
+ assert 0 <= self._loss_idx < self._num_loss
57
+ return self._loss_scaler[self._loss_idx]
58
+
59
+ def step(self, closure=None):
60
+ if not self._amp_handle.is_active():
61
+ return self._optimizer.step(closure=closure)
62
+
63
+ self._loss_idx = 0
64
+
65
+ for group in self._optimizer.param_groups:
66
+ for p in group['params']:
67
+ self._amp_handle.remove_cache(p)
68
+
69
+ if closure is not None:
70
+ raise NotImplementedError(
71
+ 'The `closure` argument is unsupported by the amp ' +
72
+ 'optimizer wrapper.')
73
+ if any(self._skip_next):
74
+ maybe_print('Gradient overflow, skipping update')
75
+ self._skip_next = [False] * self._num_loss
76
+ else:
77
+ return self._optimizer.step(closure=closure)
78
+
79
+ # Forward any attribute lookups
80
+ def __getattr__(self, attr):
81
+ return getattr(self._optimizer, attr)
82
+
83
+ # Forward all torch.optim.Optimizer methods
84
+ def __getstate__(self):
85
+ return self._optimizer.__getstate__()
86
+
87
+ def __setstate__(self):
88
+ return self._optimizer.__setstate__()
89
+
90
+ def __repr__(self):
91
+ return self._optimizer.__repr__()
92
+
93
+ def state_dict(self):
94
+ return self._optimizer.state_dict()
95
+
96
+ def load_state_dict(self, state_dict):
97
+ return self._optimizer.load_state_dict(state_dict)
98
+
99
+ def zero_grad(self):
100
+ return self._optimizer.zero_grad()
101
+
102
+ def add_param_group(self, param_group):
103
+ return self._optimizer.add_param_group(param_group)
apex/apex/amp/rnn_compat.py ADDED
@@ -0,0 +1,53 @@
1
+ from . import utils, wrap
2
+
3
+ import torch
4
+ _VF = torch._C._VariableFunctions
5
+ RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
6
+
7
+ def _gen_VF_wrapper(name):
8
+ def wrapper(*args, **kwargs):
9
+ return getattr(_VF, name)(*args, **kwargs)
10
+ return wrapper
11
+
12
+ # Some python magic to generate an object that has the rnn cell functions
13
+ # defined on it, all of which call into corresponding _VF version.
14
+ # Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF"
15
+ # imported at module scope within torch.nn.modules.rnn). This should
16
+ # not affect third-party importers of _VF.py.
17
+ class VariableFunctionsShim(object):
18
+ def __init__(self):
19
+ for name in RNN_NAMES:
20
+ for suffix in ['', '_cell']:
21
+ fn_name = name + suffix
22
+ setattr(self, fn_name, _gen_VF_wrapper(fn_name))
23
+
24
+ def has_old_rnns():
25
+ try:
26
+ torch.nn.backends.thnn.backend.LSTMCell
27
+ return True
28
+ except:
29
+ return False
30
+
31
+ def whitelist_rnn_cells(handle, verbose):
32
+ # Different module + function names in old/new RNN cases
33
+ if has_old_rnns():
34
+ fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
35
+ mod = torch.nn.backends.thnn.backend
36
+ else:
37
+ fn_names = [x + '_cell' for x in RNN_NAMES]
38
+ mod = torch.nn.modules.rnn._VF
39
+ assert isinstance(mod, VariableFunctionsShim)
40
+
41
+ # Insert casts on cell functions
42
+ for fn in fn_names:
43
+ wrap.cached_cast(mod, fn, utils.maybe_half, handle,
44
+ try_caching=True, verbose=verbose)
45
+
46
+ if has_old_rnns():
47
+ # Special handling of `backward` for fused gru / lstm:
48
+ # The `backward` method calls Tensor.sum() (blacklist) internally,
49
+ # and then the resulting grad_input has the wrong type.
50
+ # TODO: where else is this a problem?
51
+ for rnn_type in ['GRUFused', 'LSTMFused']:
52
+ mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
53
+ wrap.disable_casts(mod, 'backward', handle)
apex/apex/amp/scaler.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ from ..multi_tensor_apply import multi_tensor_applier
3
+ from ._amp_state import _amp_state, master_params, maybe_print
4
+ from itertools import product
5
+
6
+ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
7
+ # Exception handling for 18.04 compatibility
8
+ if check_overflow:
9
+ cpu_sum = float(model_grad.float().sum())
10
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
11
+ return True
12
+
13
+ if master_grad is not model_grad: # copy_ probably internally short-circuits this
14
+ master_grad.copy_(model_grad)
15
+ if scale != 1.0:
16
+ master_grad.mul_(scale)
17
+ return False
18
+
19
+ def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, scale, check_overflow=False):
20
+ # Exception handling for 18.04 compatibility
21
+ if check_overflow:
22
+ cpu_sum = float(model_grad.float().sum())
23
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
24
+ return True
25
+
26
+ # if master_grad is not model_grad: # copy_ probably internally short-circuits this
27
+ # master_grad.copy_(model_grad)
28
+ assert stashed_grad.dtype == master_grad.dtype
29
+ converted_model_grad = model_grad.to(master_grad.dtype)
30
+ stashed_grad.add_(scale, converted_model_grad)
31
+ master_grad.data = stashed_grad.data
32
+ return False
33
+
34
+ class LossScaler(object):
35
+ warned_no_fused_kernel = False
36
+ warned_unscaling_non_fp32_grad = False
37
+ has_fused_kernel = False
38
+
39
+ def __init__(self,
40
+ loss_scale,
41
+ init_scale=2.**16,
42
+ scale_factor=2.,
43
+ scale_window=2000,
44
+ min_loss_scale=None,
45
+ max_loss_scale=2.**24):
46
+ if loss_scale == "dynamic":
47
+ self.dynamic = True
48
+ self._loss_scale = init_scale
49
+ else:
50
+ self.dynamic = False
51
+ self._loss_scale = loss_scale
52
+ self._max_loss_scale = max_loss_scale
53
+ self._min_loss_scale = min_loss_scale
54
+ self._scale_seq_len = scale_window
55
+ self._unskipped = 0
56
+ self._has_overflow = False
57
+ self._overflow_buf = torch.cuda.IntTensor([0])
58
+ if multi_tensor_applier.available:
59
+ import amp_C
60
+ LossScaler.has_fused_kernel = multi_tensor_applier.available
61
+ LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
62
+ LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
63
+ else:
64
+ if not LossScaler.warned_no_fused_kernel:
65
+ maybe_print(
66
+ "Warning: multi_tensor_applier fused unscale kernel is unavailable, "
67
+ "possibly because apex was installed without --cuda_ext --cpp_ext. "
68
+ "Using Python fallback. Original ImportError was: " +
69
+ repr(multi_tensor_applier.import_err),
70
+ True)
71
+ LossScaler.has_fused_kernel = False
72
+ LossScaler.warned_no_fused_kernel = True
73
+
74
+ def loss_scale(self):
75
+ return self._loss_scale
76
+
77
+ def unscale_python(self, model_grads, master_grads, scale):
78
+ for model, master in zip(model_grads, master_grads):
79
+ if model is not None:
80
+ if not LossScaler.warned_unscaling_non_fp32_grad:
81
+ if master.dtype != torch.float32:
82
+ maybe_print(
83
+ "Attempting to unscale a grad with type {} ".format(master.type()) +
84
+ "Unscaling non-fp32 grads may indicate an error. "
85
+ "When using Amp, you don't need to call .half() on your model.")
86
+ LossScaler.warned_unscaling_non_fp32_grad = True
87
+ self._has_overflow = scale_check_overflow_python(model,
88
+ master,
89
+ 1./scale,
90
+ self.dynamic)
91
+ if self._has_overflow and self.dynamic:
92
+ break
93
+
94
+ # unused_scale keeps some of the old API alive for hopefully a short time.
95
+ def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False):
96
+ if self._has_overflow:
97
+ return
98
+
99
+ scale = self._loss_scale
100
+
101
+ if scale == 1.0 and models_are_masters and not self.dynamic:
102
+ return
103
+
104
+ if LossScaler.has_fused_kernel:
105
+ # if (not LossScaler.warned_unscaling_non_fp32_grad
106
+ # and master_grads[0].dtype == torch.float16):
107
+ # print("Warning: unscaling grads that are not FP32. "
108
+ # "Unscaling non-fp32 grads may indicate an error. "
109
+ # "When using Amp, you don't need to call .half() on your model.")
110
+ # # Setting this to True unconditionally allows the possibility of an escape
111
+ # # if never-before-seen non-fp32 grads are created in some later iteration.
112
+ # LossScaler.warned_unscaling_non_fp32_grad = True
113
+ multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
114
+ self._overflow_buf,
115
+ [model_grads, master_grads],
116
+ 1./scale)
117
+ else:
118
+ self.unscale_python(model_grads, master_grads, scale)
119
+
120
+ # Defer to update_scale
121
+ # If the fused kernel is available, we only need one D2H memcopy and sync.
122
+ # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
123
+ # self._has_overflow = self._overflow_buf.item()
124
+
125
+ def unscale_with_stashed_python(self,
126
+ model_grads,
127
+ stashed_master_grads,
128
+ master_grads,
129
+ scale):
130
+ for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
131
+ if model is None and stashed is None:
132
+ continue
133
+ else:
134
+ if not LossScaler.warned_unscaling_non_fp32_grad:
135
+ if master.dtype != torch.float32:
136
+ maybe_print(
137
+ "Attempting to unscale a grad with type {} ".format(master.type()) +
138
+ "Unscaling non-fp32 grads may indicate an error. "
139
+ "When using Amp, you don't need to call .half() on your model.")
140
+ LossScaler.warned_unscaling_non_fp32_grad = True
141
+ self._has_overflow = axpby_check_overflow_python(model,
142
+ stashed,
143
+ master,
144
+ 1./scale,
145
+ self.dynamic)
146
+ if self._has_overflow and self.dynamic:
147
+ break
148
+
149
+ def unscale_with_stashed(self,
150
+ model_grads,
151
+ stashed_master_grads,
152
+ master_grads):
153
+ if self._has_overflow:
154
+ return
155
+
156
+ scale = self._loss_scale
157
+
158
+ if LossScaler.has_fused_kernel:
159
+ if (not LossScaler.warned_unscaling_non_fp32_grad
160
+ and master_grads[0].dtype == torch.float16):
161
+ print("Warning: unscaling grads that are not FP32. "
162
+ "Unscaling non-fp32 grads may indicate an error. "
163
+ "When using Amp, you don't need to call .half() on your model.")
164
+ # Setting this to True unconditionally allows the possibility of an escape
165
+ # if never-before-seen non-fp32 grads are created in some later iteration.
166
+ LossScaler.warned_unscaling_non_fp32_grad = True
167
+ multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
168
+ self._overflow_buf,
169
+ [model_grads, stashed_master_grads, master_grads],
170
+ 1./scale,
171
+ 1.0,
172
+ 0) # check only arg 0, aka the incoming model grads, for infs
173
+ else:
174
+ self.unscale_with_stashed_python(model_grads,
175
+ stashed_master_grads,
176
+ master_grads,
177
+ scale)
178
+
179
+ # Defer to update_scale
180
+ # If the fused kernel is available, we only need one D2H memcopy and sync.
181
+ # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
182
+ # self._has_overflow = self._overflow_buf.item()
183
+
184
+ def clear_overflow_state(self):
185
+ self._has_overflow = False
186
+ if self.has_fused_kernel:
187
+ self._overflow_buf.zero_()
188
+
189
+ # Separate so unscale() can be called more than once before updating.
190
+ def update_scale(self):
191
+ # If the fused kernel is available, we only need one D2H memcopy and sync.
192
+ if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
193
+ self._has_overflow = self._overflow_buf.item()
194
+
195
+ if self._has_overflow and self.dynamic:
196
+ should_skip = True
197
+ if(self._min_loss_scale):
198
+ self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
199
+ else:
200
+ self._loss_scale = self._loss_scale/2.
201
+ self._unskipped = 0
202
+ else:
203
+ should_skip = False
204
+ self._unskipped += 1
205
+
206
+ if self._unskipped == self._scale_seq_len and self.dynamic:
207
+ self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
208
+ self._unskipped = 0
209
+
210
+ return should_skip
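The dynamic-scaling schedule implemented by ``update_scale()`` above can be restated in a few lines. The toy function below is not part of apex (and ignores ``min_loss_scale`` for brevity); it only makes the halve-on-overflow / double-after-``scale_window`` arithmetic explicit:

    # Returns (new_scale, new_unskipped) after one optimizer step.
    def next_scale(scale, unskipped, overflow, window=2000, max_scale=2.0**24):
        if overflow:
            return scale / 2.0, 0                  # halve and reset the streak
        unskipped += 1
        if unskipped == window:
            return min(max_scale, scale * 2.0), 0  # double after `window` clean steps
        return scale, unskipped

    # Starting from the default init_scale of 2**16:
    scale, unskipped = 2.0**16, 0
    scale, unskipped = next_scale(scale, unskipped, overflow=True)   # -> 32768.0
    scale, unskipped = next_scale(scale, unskipped, overflow=False)  # stays 32768.0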
apex/apex/amp/utils.py ADDED
@@ -0,0 +1,213 @@
1
+ from . import compat
2
+
3
+ import functools
4
+ import itertools
5
+
6
+ import torch
7
+
8
+ def get_cuda_version():
9
+ return tuple(int(x) for x in torch.version.cuda.split('.'))
10
+
11
+ def is_fp_tensor(x):
12
+ if is_nested(x):
13
+ # Fast-fail version of all(is_fp_tensor)
14
+ for y in x:
15
+ if not is_fp_tensor(y):
16
+ return False
17
+ return True
18
+ return compat.is_tensor_like(x) and compat.is_floating_point(x)
19
+
20
+ def is_nested(x):
21
+ return isinstance(x, tuple) or isinstance(x, list)
22
+
23
+ def should_cache(x):
24
+ if is_nested(x):
25
+ # Fast-fail version of all(should_cache)
26
+ for y in x:
27
+ if not should_cache(y):
28
+ return False
29
+ return True
30
+ return isinstance(x, torch.nn.parameter.Parameter) and \
31
+ type_string(x) == 'FloatTensor'
32
+
33
+ def collect_fp_tensor_types(args, kwargs):
34
+ def collect_types(x, types):
35
+ if is_nested(x):
36
+ for y in x:
37
+ collect_types(y, types)
38
+ else:
39
+ types.add(type_string(x))
40
+
41
+ all_args = itertools.chain(args, kwargs.values())
42
+ types = set()
43
+ for x in all_args:
44
+ if is_fp_tensor(x):
45
+ collect_types(x, types)
46
+ return types
47
+
48
+ def type_string(x):
49
+ return x.type().split('.')[-1]
50
+
51
+ def maybe_half(x, name='', verbose=False):
52
+ if is_nested(x):
53
+ return type(x)([maybe_half(y) for y in x])
54
+
55
+ if not x.is_cuda or type_string(x) == 'HalfTensor':
56
+ return x
57
+ else:
58
+ if verbose:
59
+ print('Float->Half ({})'.format(name))
60
+ return x.half()
61
+
62
+ def maybe_float(x, name='', verbose=False):
63
+ if is_nested(x):
64
+ return type(x)([maybe_float(y) for y in x])
65
+
66
+ if not x.is_cuda or type_string(x) == 'FloatTensor':
67
+ return x
68
+ else:
69
+ if verbose:
70
+ print('Half->Float ({})'.format(name))
71
+ return x.float()
72
+
73
+ # NB: returns casted `args`, mutates `kwargs` in-place
74
+ def casted_args(cast_fn, args, kwargs):
75
+ new_args = []
76
+ for x in args:
77
+ if is_fp_tensor(x):
78
+ new_args.append(cast_fn(x))
79
+ else:
80
+ new_args.append(x)
81
+ for k in kwargs:
82
+ val = kwargs[k]
83
+ if is_fp_tensor(val):
84
+ kwargs[k] = cast_fn(val)
85
+ return new_args
86
+
87
+ def cached_cast(cast_fn, x, cache):
88
+ if is_nested(x):
89
+ return type(x)([cached_cast(cast_fn, y, cache) for y in x]) # recurse with the same cast_fn and cache
90
+ if x in cache:
91
+ cached_x = cache[x]
92
+ if x.requires_grad and cached_x.requires_grad:
93
+ # Make sure x is actually cached_x's autograd parent.
94
+ if cached_x.grad_fn.next_functions[1][0].variable is not x:
95
+ raise RuntimeError("x and cache[x] both require grad, but x is not "
96
+ "cache[x]'s parent. This is likely an error.")
97
+ # During eval, it's possible to end up caching casted weights with
98
+ # requires_grad=False. On the next training iter, if cached_x is found
99
+ # and reused from the cache, it will not actually have x as its parent.
100
+ # Therefore, we choose to invalidate the cache (and force refreshing the cast)
101
+ # if x.requires_grad and cached_x.requires_grad do not match.
102
+ #
103
+ # During eval (i.e. running under with torch.no_grad()) the invalidation
104
+ # check would cause the cached value to be dropped every time, because
105
+ # cached_x would always be created with requires_grad=False, while x would
106
+ # still have requires_grad=True. This would render the cache effectively
107
+ # useless during eval. Therefore, if we are running under the no_grad()
108
+ # context manager (torch.is_grad_enabled=False) we elide the invalidation
109
+ # check, and use the cached value even though its requires_grad flag doesn't
110
+ # match. During eval, we don't care that there's no autograd-graph
111
+ # connection between x and cached_x.
112
+ if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
113
+ del cache[x]
114
+ else:
115
+ return cached_x
116
+
117
+ casted_x = cast_fn(x)
118
+ cache[x] = casted_x
119
+ return casted_x
120
+
121
+ def verbosify(cast_fn, fn_name, verbose):
122
+ if verbose:
123
+ return functools.partial(cast_fn, name=fn_name, verbose=verbose)
124
+ else:
125
+ return cast_fn
126
+
127
+ def as_inplace(fns):
128
+ for x in fns:
129
+ yield x + '_'
130
+
131
+ def has_func(mod, fn):
132
+ if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
133
+ return fn in mod.function_classes
134
+ elif isinstance(mod, dict):
135
+ return fn in mod
136
+ else:
137
+ return hasattr(mod, fn)
138
+
139
+ def get_func(mod, fn):
140
+ if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
141
+ return mod.function_classes[fn]
142
+ elif isinstance(mod, dict):
143
+ return mod[fn]
144
+ else:
145
+ return getattr(mod, fn)
146
+
147
+ def set_func(mod, fn, new_fn):
148
+ if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
149
+ mod.function_classes[fn] = new_fn
150
+ elif isinstance(mod, dict):
151
+ mod[fn] = new_fn
152
+ else:
153
+ setattr(mod, fn, new_fn)
154
+
155
+ def set_func_save(handle, mod, fn, new_fn):
156
+ cur_fn = get_func(mod, fn)
157
+ handle._save_func(mod, fn, cur_fn)
158
+ set_func(mod, fn, new_fn)
159
+
160
+ # A couple problems get solved here:
161
+ # - The flat_weight buffer is disconnected from autograd graph,
162
+ # so the fp16 weights need to be derived from the input weights
163
+ # to this forward call, not the flat buffer.
164
+ # - The ordering of weights in the flat buffer is...idiosyncratic.
165
+ # First problem is solved with combination of set_ (to set up
166
+ # correct storage) and copy_ (so the fp16 weight derives from the
167
+ # fp32 one in autograd).
168
+ # Second is solved by doing ptr arithmetic on the fp32 weights
169
+ # to derive the correct offset.
170
+ #
171
+ # TODO: maybe this should actually use
172
+ # `torch._cudnn_rnn_flatten_weight`? But then I need to call
173
+ # on first iter and cache the right offsets. Ugh.
174
+ def synthesize_flattened_rnn_weights(fp32_weights,
175
+ fp16_flat_tensor,
176
+ rnn_fn='',
177
+ verbose=False):
178
+ fp16_weights = []
179
+ fp32_base_ptr = fp32_weights[0][0].data_ptr()
180
+ for layer_weights in fp32_weights:
181
+ fp16_layer_weights = []
182
+ for w_fp32 in layer_weights:
183
+ w_fp16 = w_fp32.new().half()
184
+ offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
185
+ w_fp16.set_(fp16_flat_tensor.storage(),
186
+ offset,
187
+ w_fp32.shape)
188
+ w_fp16.copy_(w_fp32)
189
+ if verbose:
190
+ print('Float->Half ({})'.format(rnn_fn))
191
+ fp16_layer_weights.append(w_fp16)
192
+ fp16_weights.append(fp16_layer_weights)
193
+ return fp16_weights
194
+
195
+ # Roughly same as above, just the `fp32_weights` aren't nested.
196
+ # Code kept separate for readability.
197
+ def new_synthesize_flattened_rnn_weights(fp32_weights,
198
+ fp16_flat_tensor,
199
+ rnn_fn='',
200
+ verbose=False):
201
+ fp16_weights = []
202
+ fp32_base_ptr = fp32_weights[0].data_ptr()
203
+ for w_fp32 in fp32_weights:
204
+ w_fp16 = w_fp32.new().half()
205
+ offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
206
+ w_fp16.set_(fp16_flat_tensor.storage(),
207
+ offset,
208
+ w_fp32.shape)
209
+ w_fp16.copy_(w_fp32)
210
+ if verbose:
211
+ print('Float->Half ({})'.format(rnn_fn))
212
+ fp16_weights.append(w_fp16)
213
+ return fp16_weights
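Outside of Amp, the ``get_func``/``set_func`` helpers above can be exercised directly. A throwaway sketch (assuming the apex/PyTorch versions targeted here) of wrapping and restoring a torch function, which is essentially what ``wrap.py`` below automates:

    import torch
    from apex.amp import utils

    orig_tanh = utils.get_func(torch, 'tanh')

    def noisy_tanh(*args, **kwargs):
        print('tanh intercepted')            # purely illustrative side effect
        return orig_tanh(*args, **kwargs)

    utils.set_func(torch, 'tanh', noisy_tanh)
    torch.tanh(torch.zeros(1))               # prints, then computes as usual
    utils.set_func(torch, 'tanh', orig_tanh) # restore the original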
apex/apex/amp/wrap.py ADDED
@@ -0,0 +1,276 @@
1
+ from . import compat
2
+ from . import utils
3
+ from ._amp_state import _amp_state
4
+ from . import rnn_compat
5
+
6
+ import functools
7
+
8
+ import torch
9
+
10
+ def make_cast_wrapper(orig_fn, cast_fn, handle,
11
+ try_caching=False):
12
+ @functools.wraps(orig_fn)
13
+ def wrapper(*args, **kwargs):
14
+ if not handle.is_active():
15
+ return orig_fn(*args, **kwargs)
16
+
17
+ if try_caching and handle.has_cache:
18
+ args = list(args)
19
+ for i in range(len(args)):
20
+ if utils.should_cache(args[i]):
21
+ args[i] = utils.cached_cast(cast_fn, args[i], handle.cache)
22
+ for k in kwargs:
23
+ if utils.should_cache(kwargs[k]):
24
+ kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache)
25
+ new_args = utils.casted_args(cast_fn,
26
+ args,
27
+ kwargs)
28
+ return orig_fn(*new_args, **kwargs)
29
+ return wrapper
30
+
31
+ def cached_cast(mod, fn, cast_fn, handle,
32
+ try_caching=False, verbose=False):
33
+ if not utils.has_func(mod, fn):
34
+ return
35
+
36
+ orig_fn = utils.get_func(mod, fn)
37
+ cast_fn = utils.verbosify(cast_fn, fn, verbose)
38
+ wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
39
+ utils.set_func_save(handle, mod, fn, wrapper)
40
+
41
+ # `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
42
+ # Annoyingly, make_promote_wrapper still uses the global handle. Once everyone
43
+ # is on the new API and I am free to get rid of handle, I can clean this up.
44
+ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
45
+ @functools.wraps(orig_fn)
46
+ def wrapper(*args, **kwargs):
47
+ if not _amp_state.handle.is_active():
48
+ return orig_fn(*args, **kwargs)
49
+
50
+ types = utils.collect_fp_tensor_types(args, kwargs)
51
+
52
+ if len(types) <= 1:
53
+ return orig_fn(*args, **kwargs)
54
+ elif len(types) == 2 and types == set(['HalfTensor', 'FloatTensor']):
55
+ new_args = utils.casted_args(cast_fn,
56
+ args,
57
+ kwargs)
58
+ return orig_fn(*new_args, **kwargs)
59
+ else:
60
+ raise NotImplementedError('Do not know how to handle ' +
61
+ 'these types to promote: {}'
62
+ .format(types))
63
+ return wrapper
64
+
65
+ def promote(mod, fn, handle, verbose=False):
66
+ orig_fn = utils.get_func(mod, fn)
67
+ maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
68
+ wrapper = make_promote_wrapper(orig_fn, maybe_float)
69
+ utils.set_func_save(handle, mod, fn, wrapper)
70
+
71
+ def sequence_promote(mod, fn, handle, verbose=False):
72
+ orig_fn = utils.get_func(mod, fn)
73
+ maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
74
+ @functools.wraps(orig_fn)
75
+ def wrapper(seq, *args, **kwargs):
76
+ if not _amp_state.handle.is_active():
77
+ return orig_fn(seq, *args, **kwargs)
78
+
79
+ types = set([utils.type_string(x) for x in seq])
80
+ if len(types) <= 1:
81
+ return orig_fn(seq, *args, **kwargs)
82
+ elif types == set(['HalfTensor', 'FloatTensor']):
83
+ cast_seq = utils.casted_args(maybe_float,
84
+ seq, {})
85
+ return orig_fn(cast_seq, *args, **kwargs)
86
+ else:
87
+ # TODO: other mixed-type cases aren't due to amp.
88
+ # Just pass through?
89
+ return orig_fn(seq, *args, **kwargs)
90
+ utils.set_func_save(handle, mod, fn, wrapper)
91
+
92
+ def promote_match_arg0(mod, fn, handle, verbose=False):
93
+ if not utils.has_func(mod, fn):
94
+ return
95
+
96
+ orig_fn = utils.get_func(mod, fn)
97
+ @functools.wraps(orig_fn)
98
+ def wrapper(arg0, *args, **kwargs):
99
+ assert compat.is_tensor_like(arg0)
100
+ if not _amp_state.handle.is_active():
101
+ return orig_fn(arg0, *args, **kwargs)
102
+
103
+ if utils.type_string(arg0) == 'HalfTensor':
104
+ cast_fn = utils.maybe_half
105
+ elif utils.type_string(arg0) == 'FloatTensor':
106
+ cast_fn = utils.maybe_float
107
+ else:
108
+ return orig_fn(arg0, *args, **kwargs)
109
+ cast_fn = utils.verbosify(cast_fn, fn, verbose)
110
+ new_args = utils.casted_args(cast_fn, args, kwargs)
111
+ return orig_fn(arg0, *new_args, **kwargs)
112
+ utils.set_func_save(handle, mod, fn, wrapper)
113
+
114
+ def err_if_any_half(mod, fn, handle, custom_err_msg=None):
115
+ if not utils.has_func(mod, fn):
116
+ return
117
+
118
+ orig_fn = utils.get_func(mod, fn)
119
+ @functools.wraps(orig_fn)
120
+ def wrapper(*args, **kwargs):
121
+ types = utils.collect_fp_tensor_types(args, kwargs)
122
+ if 'HalfTensor' in types:
123
+ if custom_err_msg:
124
+ raise NotImplementedError(custom_err_msg)
125
+ else:
126
+ raise NotImplementedError('Cannot call in-place function ' +
127
+ '{} with fp16 arguments.'.format(fn))
128
+ else:
129
+ return orig_fn(*args, **kwargs)
130
+ utils.set_func_save(handle, mod, fn, wrapper)
131
+
132
+ def err_if_arg0_half(mod, fn, handle, verbose=False):
133
+ if not utils.has_func(mod, fn):
134
+ return
135
+
136
+ orig_fn = utils.get_func(mod, fn)
137
+ @functools.wraps(orig_fn)
138
+ def wrapper(arg0, *args, **kwargs):
139
+ assert compat.is_tensor_like(arg0)
140
+ if utils.type_string(arg0) == 'HalfTensor':
141
+ raise NotImplementedError('Cannot call in-place method ' +
142
+ '{} on fp16 Tensors.'.format(fn))
143
+ else:
144
+ cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
145
+ new_args = utils.casted_args(cast_fn, args, kwargs)
146
+ return orig_fn(arg0, *new_args, **kwargs)
147
+ utils.set_func_save(handle, mod, fn, wrapper)
148
+
149
+ # Current RNN approach:
150
+ # - Wrap top-level `RNN` function in thnn backend
151
+ # - Will call into either CudnnRNN or AutogradRNN
152
+ # - Each of these are factory functions that return a per-iter
153
+ # `forward` function
154
+ # - We interpose on the factory function to:
155
+ # 1) Interpose on the actual forward function and put in casts
156
+ # 2) Insert an fp16 `flat_weight` if necessary
157
+ def rnn_cast(backend, fn, handle, verbose=False):
158
+ orig_rnn = utils.get_func(backend, fn)
159
+ @functools.wraps(orig_rnn)
160
+ def rnn_wrapper(*args, **kwargs):
161
+ flat_weight = kwargs.get('flat_weight')
162
+ if flat_weight is not None:
163
+ # We replace `flat_weight` with an uninitialized fp16
164
+ # Tensor. The "actual" weight tensors (provided in `forward`),
165
+ # will then be set up as ptrs into the buffer and have the
166
+ # corresponding fp32 values copied in.
167
+ # We need to call `copy` on the "actual" weights so that the
168
+ # autograd graph correctly backprops from the wgrads computed
169
+ # inside cuDNN (on fp16 weights) into the fp32 weights.
170
+ assert utils.type_string(flat_weight) == 'FloatTensor'
171
+ if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
172
+ # Pre-0.4. A little slower, since it zeros out memory.
173
+ flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
174
+ else:
175
+ flat_weight_fp16 = torch.empty_like(flat_weight,
176
+ dtype=torch.float16)
177
+ kwargs['flat_weight'] = flat_weight_fp16
178
+ else:
179
+ flat_weight_fp16 = None
180
+
181
+ forward = orig_rnn(*args, **kwargs)
182
+ @functools.wraps(forward)
183
+ def fwd_wrapper(*fargs, **fkwargs):
184
+ assert len(fargs) == 3 or len(fargs) == 4
185
+ inputs, weights, hiddens = fargs[:3]
186
+ assert utils.is_fp_tensor(inputs)
187
+ assert isinstance(weights, list)
188
+ cast_fn = utils.verbosify(utils.maybe_half,
189
+ fn,
190
+ verbose)
191
+ new_args = []
192
+
193
+ # 0) Inputs
194
+ new_args.append(cast_fn(inputs))
195
+
196
+ # 1) Weights
197
+ if flat_weight_fp16 is not None:
198
+ fp16_weights = utils.synthesize_flattened_rnn_weights(
199
+ weights, flat_weight_fp16, fn, verbose)
200
+ else:
201
+ fp16_weights = [[cast_fn(w) for w in layer]
202
+ for layer in weights]
203
+ new_args.append(fp16_weights)
204
+
205
+ # 2) Inputs: either a tuple (for LSTM) or single tensor
206
+ if isinstance(hiddens, tuple):
207
+ new_args.append(tuple(cast_fn(x) for x in hiddens))
208
+ elif utils.is_fp_tensor(hiddens):
209
+ new_args.append(cast_fn(hiddens))
210
+ else:
211
+ # Hiddens can, in principle, be `None` -- pass through
212
+ new_args.append(hiddens)
213
+
214
+ # 3) Batch sizes (0.4 or later only)
215
+ if len(fargs) == 4:
216
+ new_args.append(fargs[3])
217
+
218
+ return forward(*new_args, **fkwargs)
219
+ return fwd_wrapper
220
+ utils.set_func_save(handle, backend, fn, rnn_wrapper)
221
+
222
+ def new_rnn_cast(fn, handle, verbose=False):
223
+ # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744
224
+ # For rnn backend calls that route through _rnn_impls, we must patch the ref
225
+ # that _rnn_impls stashed. For rnn backend calls that directly invoke
226
+ # _VF.<backend>, e.g. _VF.lstm, we can patch onto VariableFunctionsShim,
227
+ # which in turn has patched the ref named "_VF" in torch.nn.modules.rnn.
228
+ if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn):
229
+ mod = torch.nn.modules.rnn._rnn_impls
230
+ else:
231
+ mod = torch.nn.modules.rnn._VF
232
+ assert isinstance(mod, rnn_compat.VariableFunctionsShim)
233
+ fn = fn.lower()
234
+ orig_fn = utils.get_func(mod, fn)
235
+ cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
236
+ @functools.wraps(orig_fn)
237
+ def wrapper(*args, **kwargs):
238
+ # Exact call signature from modules/rnn.py
239
+ assert len(args) == 9
240
+ assert len(kwargs) == 0
241
+
242
+ if not _amp_state.handle.is_active():
243
+ return orig_fn(*args, **kwargs)
244
+
245
+ if isinstance(args[6], bool):
246
+ params_idx = 2 # Not PackedSequence case
247
+ else:
248
+ params_idx = 3 # PackedSequence case
249
+
250
+ new_args = []
251
+ for i, arg in enumerate(args):
252
+ if i == params_idx:
253
+ num_params = sum([x.numel() for x in arg])
254
+ fp16_weight_buf = args[0].new_empty((num_params,),
255
+ dtype=torch.half)
256
+ casted_weights = utils.new_synthesize_flattened_rnn_weights(
257
+ arg, fp16_weight_buf, fn, verbose)
258
+ new_args.append(casted_weights)
259
+ elif utils.is_fp_tensor(arg):
260
+ new_args.append(cast_fn(arg))
261
+ else:
262
+ new_args.append(arg)
263
+
264
+ return orig_fn(*new_args)
265
+ utils.set_func_save(handle, mod, fn, wrapper)
266
+
267
+ def disable_casts(mod, fn, handle):
268
+ if not utils.has_func(mod, fn):
269
+ return
270
+
271
+ orig_fn = utils.get_func(mod, fn)
272
+ @functools.wraps(orig_fn)
273
+ def wrapper(*args, **kwargs):
274
+ with handle._disable_casts():
275
+ return orig_fn(*args, **kwargs)
276
+ utils.set_func_save(handle, mod, fn, wrapper)
apex/apex/fp16_utils/README.md ADDED
@@ -0,0 +1,16 @@
1
+ fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an existing PyTorch optimizer and automatically enable master parameters and loss scaling in a manner transparent to the user. To use `FP16_Optimizer`, only two lines of the training script need to change.
2
+
3
+ #### [FP16_Optimizer API documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling)
4
+
5
+ #### [Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple)
6
+
7
+ #### [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
8
+
9
+ #### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model)
10
+
11
+
12
+ fp16_util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses.
13
+
14
+ #### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management)
15
+
16
+ The [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) and [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) directories also contain `main.py` files that demonstrate manual management of master parameters and static loss scaling. These examples illustrate what sort of operations `FP16_Optimizer` is performing automatically.
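The "two changed lines" the README refers to look roughly as follows (toy model and loss, purely illustrative; the authoritative pattern is in the ``FP16_Optimizer`` docstring further down):

    import torch
    from apex.fp16_utils import FP16_Optimizer

    model = torch.nn.Linear(512, 10).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)  # change 1: wrap the optimizer

    x = torch.randn(32, 512, device="cuda").half()
    loss = model(x).float().pow(2).mean()
    optimizer.backward(loss)   # change 2: replaces loss.backward()
    optimizer.step()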
apex/apex/fp16_utils/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from .fp16util import (
2
+ BN_convert_float,
3
+ network_to_half,
4
+ prep_param_lists,
5
+ model_grads_to_master_grads,
6
+ master_params_to_model_params,
7
+ tofp16,
8
+ to_python_float,
9
+ clip_grad_norm,
10
+ convert_module,
11
+ convert_network,
12
+ FP16Model,
13
+ )
14
+
15
+ from .fp16_optimizer import FP16_Optimizer
16
+ from .loss_scaler import LossScaler, DynamicLossScaler
apex/apex/fp16_utils/fp16_optimizer.py ADDED
@@ -0,0 +1,643 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.autograd import Variable
4
+ from torch.nn.parameter import Parameter
5
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
6
+
7
+ from ..amp._amp_state import _amp_state, maybe_print
8
+ from ..amp.scaler import LossScaler
9
+ from ..multi_tensor_apply import multi_tensor_applier
10
+ from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
11
+
12
+ # TODO: Update overflow check + downscale to use Carl's fused kernel.
13
+ class FP16_Optimizer(object):
14
+ """
15
+ :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
16
+ and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
17
+ For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
18
+ and changing the call to ``backward``.
19
+
20
+ Example::
21
+
22
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
23
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
24
+ # Name the FP16_Optimizer instance to replace the existing optimizer
25
+ # (recommended but not required):
26
+ optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
27
+ ...
28
+ # loss.backward() becomes:
29
+ optimizer.backward(loss)
30
+ ...
31
+
32
+ Example with dynamic loss scaling::
33
+
34
+ ...
35
+ optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
36
+ # optional arg to control dynamic loss scaling behavior
37
+ # dynamic_loss_args={'scale_window' : 500})
38
+ # Usually, dynamic_loss_args is not necessary.
39
+
40
+ Args:
41
+ init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
42
+ static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
43
+ dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
44
+ dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`LossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`LossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`LossScaler`'s defaults will be used.
45
+ verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.
46
+
47
+ ``init_optimizer`` is expected to have been constructed in the ordinary way.
48
+ It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
49
+ named to replace ``init_optimizer``, for two reasons:
50
+ First, it means that references to the same name
51
+ later in the file will not have to change.
52
+ Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
53
+ modify ``init_optimizer``. If you do choose a unique name for the new
54
+ :class:`FP16_Optimizer` instance, you should only work with this new instance,
55
+ because the preexisting optimizer might no longer behave as expected.
56
+
57
+ ``init_optimizer`` may be any Pytorch optimizer.
58
+ It may contain a mixture of fp16 and fp32 parameters organized into any number of
59
+ ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
60
+ ingest these ``param_groups`` and remember them.
61
+
62
+ Calls to ::
63
+
64
+ loss.backward()
65
+
66
+ must be replaced with ::
67
+
68
+ optimizer.backward(loss)
69
+
70
+ because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
71
+ loss scaling and copies to master gradients.
72
+
73
+ .. note::
74
+ Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
75
+ are downscaled before being applied. This means that adjusting the loss scale, or using
76
+ dynamic loss scaling, should not require retuning the learning rate or any other
77
+ hyperparameters.
78
+
79
+
80
+ **Advanced options**
81
+
82
+ **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
83
+ See docstring for :attr:`step`.
84
+
85
+ **Gradient clipping**: Use :attr:`clip_master_grads`.
86
+
87
+ **Multiple losses**: If your model accumulates gradients from multiple losses,
88
+ this can be made more efficient by supplying ``update_master_grads=False``
89
+ to :attr:`backward`. See docstring for :attr:`backward`.
90
+
91
+ **Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::
92
+
93
+ print(optimizer.loss_scale)
94
+ optimizer.loss_scale = new_loss_scale
95
+
96
+ For static loss scaling, manually adjusting the loss scale over time is a reasonable
97
+ thing to do. During later epochs, gradients may become smaller, and a
98
+ higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
99
+ scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
100
+ the loss scale is not recommended.
101
+
102
+ **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
103
+ Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
104
+ should still work as intended.
105
+ """
106
+
107
+ def __init__(self,
108
+ init_optimizer,
109
+ static_loss_scale=1.0,
110
+ dynamic_loss_scale=False,
111
+ dynamic_loss_args=None,
112
+ verbose=True):
113
+ if not torch.cuda.is_available():
114
+ raise SystemError("Cannot use fp16 without CUDA.")
115
+
116
+ self.verbose = verbose
117
+
118
+ self.optimizer = init_optimizer
119
+ # init_state_dict sets up an alternative way to cast per-param state tensors.
120
+ # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
121
+ # init_state_dict = init_optimizer.state_dict()
122
+
123
+ self.fp16_groups = []
124
+ self.fp32_from_fp16_groups = []
125
+ self.fp32_from_fp32_groups = []
126
+ for i, param_group in enumerate(self.optimizer.param_groups):
127
+ self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
128
+ fp16_params_this_group = []
129
+ fp32_params_this_group = []
130
+ fp32_from_fp16_params_this_group = []
131
+ for i, param in enumerate(param_group['params']):
132
+ if param.requires_grad:
133
+ if param.type() == 'torch.cuda.HalfTensor':
134
+ self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
135
+ .format(param.size()))
136
+ fp16_params_this_group.append(param)
137
+ master_param = param.detach().clone().float()
138
+ master_param.requires_grad = True
139
+ param_group['params'][i] = master_param
140
+ fp32_from_fp16_params_this_group.append(master_param)
141
+ # Reset existing state dict key to the new master param.
142
+ # We still need to recast per-param state tensors, if any, to FP32.
143
+ if param in self.optimizer.state:
144
+ self.optimizer.state[master_param] = self.optimizer.state.pop(param)
145
+ elif param.type() == 'torch.cuda.FloatTensor':
146
+ self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
147
+ .format(param.size()))
148
+ fp32_params_this_group.append(param)
149
+ param_group['params'][i] = param
150
+ else:
151
+ raise TypeError("Wrapped parameters must be either "
152
+ "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
153
+ "Received {}".format(param.type()))
154
+
155
+ self.fp16_groups.append(fp16_params_this_group)
156
+ self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
157
+ self.fp32_from_fp32_groups.append(fp32_params_this_group)
158
+
159
+ self.all_fp16_params = []
160
+ for group in self.fp16_groups:
161
+ self.all_fp16_params += group
162
+
163
+ self.all_fp32_from_fp16_params = []
164
+ for group in self.fp32_from_fp16_groups:
165
+ self.all_fp32_from_fp16_params += group
166
+
167
+ self.all_fp32_from_fp32_params = []
168
+ for group in self.fp32_from_fp32_groups:
169
+ self.all_fp32_from_fp32_params += group
170
+
171
+ # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
172
+ self.optimizer.load_state_dict(self.optimizer.state_dict())
173
+ # alternative way to cast per-param state tensors:
174
+ # self.optimizer.load_state_dict(init_state_dict)
175
+
176
+ if dynamic_loss_scale:
177
+ self.dynamic_loss_scale = True
178
+ if dynamic_loss_args is not None:
179
+ self.loss_scaler = LossScaler("dynamic", **dynamic_loss_args)
180
+ else:
181
+ self.loss_scaler = LossScaler("dynamic")
182
+ else:
183
+ self.dynamic_loss_scale = False
184
+ self.loss_scaler = LossScaler(static_loss_scale)
185
+
186
+ self.overflow = False
187
+ self.first_closure_call_this_step = True
188
+
189
+ self.clip_grad_norm = clip_grad_norm
190
+
191
+ # TODO: Centralize exposure and import error checking for the C backend.
192
+ if multi_tensor_applier.available:
193
+ import amp_C
194
+ self.multi_tensor_scale = amp_C.multi_tensor_scale
195
+ self._dummy_overflow_buf = torch.cuda.IntTensor([0])
196
+
197
+ # Having self.maybe_print distinct from _amp_state.maybe_print is another artifact
198
+ # of having to support FP16_Optimizer separately, for the time being.
199
+ def maybe_print(self, msg):
200
+ if self.verbose:
201
+ print(msg)
202
+
203
+ def __getstate__(self):
204
+ raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
205
+
206
+ def __setstate__(self, state):
207
+ raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
208
+
209
+ def zero_grad(self, set_grads_to_None=False):
210
+ """
211
+ Zero fp32 and fp16 parameter grads.
212
+ """
213
+ # In principle, only the .grad attributes of the model params need to be zeroed,
214
+ # because gradients are copied into the FP32 master params. However, we zero
215
+ # all gradients owned by the optimizer, just to be safe:
216
+ for group in self.optimizer.param_groups:
217
+ for p in group['params']:
218
+ if set_grads_to_None:
219
+ p.grad = None
220
+ else:
221
+ if p.grad is not None:
222
+ p.grad.detach_()
223
+ p.grad.zero_()
224
+
225
+ # Zero fp16 gradients owned by the model:
226
+ for fp16_group in self.fp16_groups:
227
+ for param in fp16_group:
228
+ if set_grads_to_None:
229
+ param.grad = None
230
+ else:
231
+ if param.grad is not None:
232
+ param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
233
+ param.grad.zero_()
234
+
235
+ # Should not be used anymore.
236
+ # def _check_overflow(self):
237
+ # params = []
238
+ # for group in self.fp16_groups:
239
+ # for param in group:
240
+ # params.append(param)
241
+ # for group in self.fp32_from_fp32_groups:
242
+ # for param in group:
243
+ # params.append(param)
244
+ # self.overflow = self.loss_scaler.has_overflow(params)
245
+
246
+ # def _update_scale(self, has_overflow=False):
247
+ # self.loss_scaler.update_scale(has_overflow)
248
+
249
+ def _master_params_to_model_params(self):
250
+ if multi_tensor_applier.available:
251
+ if len(self.all_fp16_params) > 0:
252
+ multi_tensor_applier(
253
+ self.multi_tensor_scale,
254
+ self._dummy_overflow_buf,
255
+ [self.all_fp32_from_fp16_params, self.all_fp16_params],
256
+ 1.0)
257
+ else:
258
+ for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
259
+ master_params_to_model_params(fp16_group, fp32_from_fp16_group)
260
+
261
+ # To consider: Integrate distributed with this wrapper by registering a hook on each variable
262
+ # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
263
+ # def _model_grads_to_master_grads(self):
264
+ # for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
265
+ # model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
266
+
267
+ # def _downscale_master(self):
268
+ # if self.loss_scale != 1.0:
269
+ # for group in self.optimizer.param_groups:
270
+ # for param in group['params']:
271
+ # if param.grad is not None:
272
+ # param.grad.data.mul_(1./self.loss_scale)
273
+
274
+ def clip_master_grads(self, max_norm, norm_type=2):
275
+ """
276
+ Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
277
+
278
+ Args:
279
+ max_norm (float or int): max norm of the gradients
280
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
281
+ infinity norm.
282
+
283
+ Returns:
284
+ Total norm of the current fp32 gradients (viewed as a single vector).
285
+
286
+ .. warning::
287
+ Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
288
+ """
289
+ if not self.overflow:
290
+ fp32_params = []
291
+ for param_group in self.optimizer.param_groups:
292
+ for param in param_group['params']:
293
+ fp32_params.append(param)
294
+ return self.clip_grad_norm(fp32_params, max_norm, norm_type)
295
+ else:
296
+ return -1
297
+
298
+ def state_dict(self):
299
+ """
300
+ Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
301
+ This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
302
+ of the contained Pytorch optimizer.
303
+ Example::
304
+
305
+ checkpoint = {}
306
+ checkpoint['model'] = model.state_dict()
307
+ checkpoint['optimizer'] = optimizer.state_dict()
308
+ torch.save(checkpoint, "saved.pth")
309
+ """
310
+ state_dict = {}
311
+ state_dict['loss_scaler'] = self.loss_scaler
312
+ state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
313
+ state_dict['overflow'] = self.overflow
314
+ state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
315
+ state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
316
+ state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
317
+ return state_dict
318
+
319
+ def load_state_dict(self, state_dict):
320
+ """
321
+ Loads a state_dict created by an earlier call to state_dict().
322
+ If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
323
+ whose parameters in turn came from ``model``, it is expected that the user
324
+ will call ``model.load_state_dict()`` before
325
+ ``fp16_optimizer_instance.load_state_dict()`` is called.
326
+
327
+ Example::
328
+
329
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
330
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
331
+ optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
332
+ ...
333
+ checkpoint = torch.load("saved.pth")
334
+ model.load_state_dict(checkpoint['model'])
335
+ optimizer.load_state_dict(checkpoint['optimizer'])
336
+ """
337
+ # I think it should actually be ok to reload the optimizer before the model.
338
+ self.loss_scaler = state_dict['loss_scaler']
339
+ self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
340
+ self.overflow = state_dict['overflow']
341
+ self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
342
+ self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
343
+ # At this point, the optimizer's references to the model's fp32 parameters are up to date.
344
+ # The optimizer's hyperparameters and internal buffers are also up to date.
345
+ # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
346
+ # out of date. There are two options.
347
+ # 1: Refresh the master params from the model's fp16 params.
348
+ # This requires less storage but incurs precision loss.
349
+ # 2: Save and restore the fp32 master copies separately.
350
+ # We choose option 2.
351
+ #
352
+ # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
353
+ # of their associated parameters, because it's possible those buffers might not exist yet in
354
+ # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
355
+ # constructed in the same way as the one whose state_dict we are loading, the same master params
356
+ # are guaranteed to exist, so we can just copy_() from the saved master params.
357
+ for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
358
+ for current, saved in zip(current_group, saved_group):
359
+ current.data.copy_(saved.data)
360
+
361
+ def step(self, closure=None): # could add clip option.
362
+ """
363
+ If no closure is supplied, :attr:`step` should be called after
364
+ ``fp16_optimizer_obj.backward(loss)``.
365
+ :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
366
+ :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
367
+ originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
368
+ another forward pass using their model.
369
+
370
+ If a closure is supplied, :attr:`step` may be called without a prior call to
371
+ :attr:`backward(loss)`.
372
+ This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
373
+ However, the user should take care that any ``loss.backward()`` call within the closure
374
+ has been replaced by ``fp16_optimizer_obj.backward(loss)``.
375
+
376
+ Args:
377
+ closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.
378
+
379
+ Example with closure::
380
+
381
+ # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
382
+ # existing pytorch optimizer.
383
+ for input, target in dataset:
384
+ def closure():
385
+ optimizer.zero_grad()
386
+ output = model(input)
387
+ loss = loss_fn(output, target)
388
+ # loss.backward() becomes:
389
+ optimizer.backward(loss)
390
+ return loss
391
+ optimizer.step(closure)
392
+
393
+ .. warning::
394
+ Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.
395
+
396
+ .. _`ordinary Pytorch optimizer use`:
397
+ http://pytorch.org/docs/master/optim.html#optimizer-step-closure
398
+ """
399
+
400
+ scale = self.loss_scaler.loss_scale()
401
+ # To consider: Should this be in step(), or update_master_grads? It works either way,
402
+ # but I should make it consistent with the Amp control flow, which updates the scale
403
+ # during backward context manager exit.
404
+ # self._update_scale(self.overflow)
405
+
406
+ if self.overflow:
407
+ # Using _amp_state.maybe_print instead of self.print here is intentional.
408
+ maybe_print("Gradient overflow. Skipping step, reducing " +
409
+ "loss scale to {}".format(self.loss_scaler.loss_scale()))
410
+ return
411
+
412
+ if closure is not None:
413
+ retval = self._step_with_closure(closure)
414
+ else:
415
+ # torch.cuda.nvtx.range_push("pytorch optimizer step")
416
+ retval = self.optimizer.step()
417
+ # torch.cuda.nvtx.range_pop()
418
+
419
+ self._master_params_to_model_params()
420
+
421
+ return retval
422
+
423
+ def _step_with_closure(self, closure):
424
+ def wrapped_closure():
425
+ # helpful for debugging
426
+ # print("Calling wrapped_closure, first_closure_call_this_step = {}"
427
+ # .format(self.first_closure_call_this_step))
428
+ if self.first_closure_call_this_step:
429
+ # We expect that the fp16 params are initially fresh on entering self.step(),
430
+ # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
431
+ # is called within self.optimizer.step().
432
+ self.first_closure_call_this_step = False
433
+ else:
434
+ # If self.optimizer.step() internally calls wrapped_closure more than once,
435
+ # it may update the fp32 params after each call. However, self.optimizer
436
+ # doesn't know about the fp16 params at all. If the fp32 params get updated,
437
+ # we can't rely on self.optimizer to refresh the fp16 params. We need
438
+ # to handle that manually:
439
+ self._master_params_to_model_params()
440
+ # Our API expects the user to give us ownership of the backward() call by
441
+ # replacing all calls to loss.backward() with optimizer.backward(loss).
442
+ # This requirement holds whether or not the call to backward() is made within a closure.
443
+ # If the user is properly calling optimizer.backward(loss) within "closure,"
444
+ # calling closure() here will give the fp32 master params fresh gradients
445
+ # for the optimizer to play with, so all wrapped_closure needs to do is call
446
+ # closure() and return the loss.
447
+ temp_loss = closure()
448
+ while self.overflow:
449
+ scale = self.loss_scaler.loss_scale()
450
+ # self._update_scale(self.overflow) # now done at the end of backward
451
+ print("OVERFLOW within closure! Skipping step, reducing loss scale to {}".format(
452
+ self.loss_scaler.loss_scale()))
453
+ temp_loss = closure()
454
+ return temp_loss
455
+
456
+ retval = self.optimizer.step(wrapped_closure)
457
+
458
+ self.first_closure_call_this_step = True
459
+
460
+ return retval
461
+
462
+ def backward(self, loss, update_master_grads=True, retain_graph=False):
463
+ """
464
+ :attr:`backward` performs the following conceptual steps:
465
+
466
+ 1. fp32_loss = loss.float() (see first Note below)
467
+ 2. scaled_loss = fp32_loss*loss_scale
468
+ 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
469
+ 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
470
+ 5. Finally, master grads are divided by loss_scale.
471
+
472
+ In this way, after :attr:`backward`, the master params have fresh gradients,
473
+ and :attr:`step` may be called.
474
+
475
+ .. note::
476
+ :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
477
+ This provides some additional safety against overflow if the user has supplied an
478
+ fp16 loss value.
479
+ However, for maximum overflow safety, the user should
480
+ compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
481
+ :attr:`backward`.
482
+
483
+ .. warning::
484
+ The gradients found in a model's leaves after the call to
485
+ :attr:`backward` should not be regarded as valid in general,
486
+ because it's possible
487
+ they have been scaled (and in the case of dynamic loss scaling,
488
+ the scale factor may change over time).
489
+ If the user wants to inspect gradients after a call to :attr:`backward`,
490
+ only the master gradients should be regarded as valid. These can be retrieved via
491
+ :attr:`inspect_master_grad_data()`.
492
+
493
+ Args:
494
+ loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
495
+ update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
496
+ retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
497
+
498
+ Example::
499
+
500
+ # Ordinary operation:
501
+ optimizer.backward(loss)
502
+
503
+ # Naive operation with multiple losses (technically valid, but less efficient):
504
+ # fp32 grads will be correct after the second call, but
505
+ # the first call incurs an unnecessary fp16->fp32 grad copy.
506
+ optimizer.backward(loss1)
507
+ optimizer.backward(loss2)
508
+
509
+ # More efficient way to handle multiple losses:
510
+ # The fp16->fp32 grad copy is delayed until fp16 grads from all
511
+ # losses have been accumulated.
512
+ optimizer.backward(loss1, update_master_grads=False)
513
+ optimizer.backward(loss2, update_master_grads=False)
514
+ optimizer.update_master_grads()
515
+ """
516
+ # To consider: try multiple backward passes using retain_graph=True to find
517
+ # a loss scale that works. After you find a loss scale that works, do a final dummy
518
+ # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
519
+ # discarding the iteration, but probably wouldn't improve overall efficiency.
520
+ scaled_loss = loss.float()*self.loss_scaler.loss_scale()
521
+ scaled_loss.backward(retain_graph=retain_graph)
522
+ if update_master_grads:
523
+ self.update_master_grads()
524
+
525
+ def update_master_grads(self):
526
+ # torch.cuda.nvtx.range_push("update_master_grads")
527
+ """
528
+ Copy the ``.grad`` attribute from stored references to fp16 parameters to
529
+ the ``.grad`` attribute of the fp32 master parameters that are directly
530
+ updated by the optimizer. :attr:`update_master_grads` only needs to be called if
531
+ ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
532
+ """
533
+ # if self.dynamic_loss_scale:
534
+ # self._check_overflow()
535
+ # if self.overflow: return
536
+ # self._model_grads_to_master_grads()
537
+ # self._downscale_master()
538
+ # Use the one-shot multi-tensor apply kernel
539
+ self.loss_scaler.clear_overflow_state()
540
+ if len(self.all_fp16_params) > 0:
541
+ # print("Model grads before")
542
+ # print([param.grad.data for param in self.all_fp16_params])
543
+ # I'm ONLY writing this as an incremental way to make some tests pass until
544
+ # I can refactor the tests as well.
545
+ # FP16_Optimizer should not be used by anyone.
546
+ model_grads = []
547
+ master_grads = []
548
+ for model_param, master_param in zip(self.all_fp16_params,
549
+ self.all_fp32_from_fp16_params):
550
+ if model_param.grad is not None:
551
+ model_grads.append(model_param.grad)
552
+ if master_param.grad is None:
553
+ master_param.grad = torch.empty_like(master_param)
554
+ master_grads.append(master_param.grad)
555
+ self.loss_scaler.unscale(
556
+ model_grads,
557
+ master_grads,
558
+ self.loss_scaler.loss_scale())
559
+ # print("Master grads after")
560
+ # print([param.grad.data for param in self.all_fp32_from_fp16_params])
561
+ if len(self.all_fp32_from_fp32_params) > 0:
562
+ model_grads = []
563
+ master_grads = []
564
+ for model_param, master_param in zip(self.all_fp32_from_fp32_params,
565
+ self.all_fp32_from_fp32_params):
566
+ if model_param.grad is not None:
567
+ model_grads.append(model_param.grad)
568
+ master_grads.append(master_param.grad)
569
+ # print("Model grads before")
570
+ # print([param.grad.data for param in self.all_fp32_from_fp32_params])
571
+ self.loss_scaler.unscale(
572
+ model_grads,
573
+ master_grads,
574
+ self.loss_scaler.loss_scale())
575
+ # print("Master grads after")
576
+ # print([param.grad.data for param in self.all_fp32_from_fp32_params])
577
+ # quit()
578
+ self.overflow = self.loss_scaler.update_scale()
579
+ # torch.cuda.nvtx.range_pop()
580
+
581
+
582
+ def inspect_master_grad_data(self):
583
+ """
584
+ When running with :class:`FP16_Optimizer`,
585
+ ``.grad`` attributes of a model's fp16 leaves should not be
586
+ regarded as truthful, because they might be scaled.
587
+ After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
588
+ the fp32 master params' ``.grad``
589
+ attributes will contain valid gradients properly divided by the loss scale. However,
590
+ because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
591
+ nonintuitive. :attr:`inspect_master_grad_data`
592
+ allows those gradients to be viewed with shapes corresponding to their associated model leaves.
593
+
594
+ Returns:
595
+ List of lists (one list for each parameter group). The list for each parameter group
596
+ is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
597
+ """
598
+ if self.overflow:
599
+ print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
600
+ "Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
601
+ return None
602
+ else:
603
+ # The optimizer owns only references to master params.
604
+ master_grads_data = []
605
+ for param_group in self.optimizer.param_groups:
606
+ master_grads_this_group = []
607
+ for param in param_group['params']:
608
+ if param.grad is not None:
609
+ master_grads_this_group.append(param.grad.data)
610
+ else:
611
+ master_grads_this_group.append(None)
612
+ master_grads_data.append(master_grads_this_group)
613
+ return master_grads_data
614
+
615
+
616
+ # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
617
+ def _get_loss_scale(self):
618
+ return self.loss_scaler.loss_scale()
619
+
620
+ def _set_loss_scale(self, value):
621
+ self.loss_scaler._loss_scale = value
622
+
623
+ loss_scale = property(_get_loss_scale, _set_loss_scale)
624
+
625
+ # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
626
+ def _get_state(self):
627
+ return self.optimizer.state
628
+
629
+ def _set_state(self, value):
630
+ self.optimizer.state = value
631
+
632
+ state = property(_get_state, _set_state)
633
+
634
+ # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
635
+ # (for example, to adjust the learning rate)
636
+ def _get_param_groups(self):
637
+ return self.optimizer.param_groups
638
+
639
+ def _set_param_groups(self, value):
640
+ self.optimizer.param_groups = value
641
+
642
+ param_groups = property(_get_param_groups, _set_param_groups)
643
+
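
Putting the pieces of the class above together, a minimal training-loop sketch with dynamic loss scaling and gradient clipping; `loader`, `loss_fn`, the layer sizes, and the clipping threshold are illustrative::

    import torch
    from apex.fp16_utils import FP16_Optimizer

    model = torch.nn.Linear(512, 10).cuda().half()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

    for inputs, targets in loader:                 # illustrative data loader
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        optimizer.backward(loss)                   # replaces loss.backward()
        if optimizer.clip_master_grads(5.0) == -1:
            continue                               # fp16 grads overflowed; master grads are invalid
        optimizer.step()                           # step() would also skip the update on overflow
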
apex/apex/fp16_utils/fp16util.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.autograd import Variable
4
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
5
+
6
+
7
+ class tofp16(nn.Module):
8
+ """
9
+ Utility module that implements::
10
+
11
+ def forward(self, input):
12
+ return input.half()
13
+ """
14
+
15
+ def __init__(self):
16
+ super(tofp16, self).__init__()
17
+
18
+ def forward(self, input):
19
+ return input.half()
20
+
21
+
22
+ def BN_convert_float(module):
23
+ """
24
+ Utility function for network_to_half().
25
+
26
+ Retained for legacy purposes.
27
+ """
28
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
29
+ module.float()
30
+ for child in module.children():
31
+ BN_convert_float(child)
32
+ return module
33
+
34
+
35
+ def network_to_half(network):
36
+ """
37
+ Convert model to half precision in a batchnorm-safe way.
38
+
39
+ Retained for legacy purposes. It is recommended to use FP16Model.
40
+ """
41
+ return nn.Sequential(tofp16(), BN_convert_float(network.half()))
42
+
43
+
44
+ def convert_module(module, dtype):
45
+ """
46
+ Converts a module's immediate parameters and buffers to dtype.
47
+ """
48
+ for param in module.parameters(recurse=False):
49
+ if param is not None:
50
+ if param.data.dtype.is_floating_point:
51
+ param.data = param.data.to(dtype=dtype)
52
+ if param._grad is not None and param._grad.data.dtype.is_floating_point:
53
+ param._grad.data = param._grad.data.to(dtype=dtype)
54
+
55
+ for buf in module.buffers(recurse=False):
56
+ if buf is not None and buf.data.dtype.is_floating_point:
57
+ buf.data = buf.data.to(dtype=dtype)
58
+
59
+
60
+ def convert_network(network, dtype):
61
+ """
62
+ Converts a network's parameters and buffers to dtype.
63
+ """
64
+ for module in network.modules():
65
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
66
+ continue
67
+ convert_module(module, dtype)
68
+ if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase):
69
+ module.flatten_parameters()
70
+ return network
71
+
72
+
73
+ class FP16Model(nn.Module):
74
+ """
75
+ Convert model to half precision in a batchnorm-safe way.
76
+ """
77
+
78
+ def __init__(self, network):
79
+ super(FP16Model, self).__init__()
80
+ self.network = convert_network(network, dtype=torch.half)
81
+
82
+ def forward(self, *inputs):
83
+ inputs = tuple(t.half() for t in inputs)
84
+ return self.network(*inputs)
85
+
86
+
87
+ def backwards_debug_hook(grad):
88
+ raise RuntimeError("master_params received a gradient in the backward pass!")
89
+
90
+ def prep_param_lists(model, flat_master=False):
91
+ """
92
+ Creates a list of FP32 master parameters for a given model, as in
93
+ `Training Neural Networks with Mixed Precision: Real Examples`_.
94
+
95
+ Args:
96
+ model (torch.nn.Module): Existing Pytorch model
97
+ flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization.
98
+ Returns:
99
+ A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master parameters. If ``flat_master=True``, ``master_params`` will be a list with one element.
100
+
101
+ Example::
102
+
103
+ model_params, master_params = prep_param_lists(model)
104
+
105
+ .. warning::
106
+ Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.
107
+
108
+ .. _`Training Neural Networks with Mixed Precision: Real Examples`:
109
+ http://on-demand.gputechconf.com/gtc/2018/video/S81012/
110
+ """
111
+ model_params = [param for param in model.parameters() if param.requires_grad]
112
+
113
+ if flat_master:
114
+ # Give the user some more useful error messages
115
+ try:
116
+ # flatten_dense_tensors returns a contiguous flat array.
117
+ # http://pytorch.org/docs/master/_modules/torch/_utils.html
118
+ master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
119
+ except:
120
+ print("Error in prep_param_lists: model may contain a mixture of parameters "
121
+ "of different types. Use flat_master=False, or use FP16_Optimizer.")
122
+ raise
123
+ master_params = torch.nn.Parameter(master_params)
124
+ master_params.requires_grad = True
125
+ # master_params.register_hook(backwards_debug_hook)
126
+ if master_params.grad is None:
127
+ master_params.grad = master_params.new(*master_params.size())
128
+ return model_params, [master_params]
129
+ else:
130
+ master_params = [param.clone().float().detach() for param in model_params]
131
+ for param in master_params:
132
+ param.requires_grad = True
133
+ return model_params, master_params
134
+
135
+
136
+ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
137
+ """
138
+ Copy model gradients to master gradients.
139
+
140
+ Args:
141
+ model_params: List of model parameters created by :func:`prep_param_lists`.
142
+ master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
143
+ """
144
+ if flat_master:
145
+ # The flattening may incur one more deep copy than is necessary.
146
+ master_params[0].grad.data.copy_(
147
+ _flatten_dense_tensors([p.grad.data for p in model_params]))
148
+ else:
149
+ for model, master in zip(model_params, master_params):
150
+ if model.grad is not None:
151
+ if master.grad is None:
152
+ master.grad = Variable(master.data.new(*master.data.size()))
153
+ master.grad.data.copy_(model.grad.data)
154
+ else:
155
+ master.grad = None
156
+
157
+
158
+ def master_params_to_model_params(model_params, master_params, flat_master=False):
159
+ """
160
+ Copy master parameters to model parameters.
161
+
162
+ Args:
163
+ model_params: List of model parameters created by :func:`prep_param_lists`.
164
+ master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
165
+ """
166
+ if flat_master:
167
+ for model, master in zip(model_params,
168
+ _unflatten_dense_tensors(master_params[0].data, model_params)):
169
+ model.data.copy_(master)
170
+ else:
171
+ for model, master in zip(model_params, master_params):
172
+ model.data.copy_(master.data)
173
+
174
+ # Backward compatibility fixes
175
+
176
+ def to_python_float(t):
177
+ if hasattr(t, 'item'):
178
+ return t.item()
179
+ else:
180
+ return t[0]
181
+
182
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
183
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
184
+ if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
185
+ clip_grad_norm = torch.nn.utils.clip_grad_norm
186
+ else:
187
+ clip_grad_norm = torch.nn.utils.clip_grad_norm_
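
The ``flat_master=True`` path above changes the shapes involved: the master copy becomes a single flattened fp32 tensor. A short sketch of that variant, assuming a model whose parameters are all fp16 and using illustrative layer sizes::

    import torch
    from apex.fp16_utils import (prep_param_lists, model_grads_to_master_grads,
                                 master_params_to_model_params)

    model = torch.nn.Linear(512, 10).cuda().half()
    model_params, master_params = prep_param_lists(model, flat_master=True)
    assert len(master_params) == 1                 # one flat fp32 Parameter holds every weight
    optimizer = torch.optim.SGD(master_params, lr=1e-3)

    # ...after a backward pass has populated the fp16 .grad attributes:
    model_grads_to_master_grads(model_params, master_params, flat_master=True)
    optimizer.step()
    master_params_to_model_params(model_params, master_params, flat_master=True)
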
apex/apex/fp16_utils/loss_scaler.py ADDED
@@ -0,0 +1,186 @@
1
+ import torch
2
+
3
+ # item() is a recent addition, so this helps with backward compatibility.
4
+ def to_python_float(t):
5
+ if hasattr(t, 'item'):
6
+ return t.item()
7
+ else:
8
+ return t[0]
9
+
10
+ class LossScaler:
11
+ """
12
+ Class that manages a static loss scale. This class is intended to interact with
13
+ :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
14
+
15
+ Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
16
+ :class:`FP16_Optimizer`'s constructor.
17
+
18
+ Args:
19
+ scale (float, optional, default=1.0): The loss scale.
20
+ """
21
+
22
+ def __init__(self, scale=1):
23
+ self.cur_scale = scale
24
+
25
+ # `params` is a list / generator of torch.Variable
26
+ def has_overflow(self, params):
27
+ return False
28
+
29
+ # `x` is a torch.Tensor
30
+ def _has_inf_or_nan(x):
31
+ return False
32
+
33
+ def update_scale(self, overflow):
34
+ pass
35
+
36
+ @property
37
+ def loss_scale(self):
38
+ return self.cur_scale
39
+
40
+ def scale_gradient(self, module, grad_in, grad_out):
41
+ return tuple(self.loss_scale * g for g in grad_in)
42
+
43
+ def backward(self, loss, retain_graph=False):
44
+ scaled_loss = loss*self.loss_scale
45
+ scaled_loss.backward(retain_graph=retain_graph)
46
+
47
+ class DynamicLossScaler:
48
+ """
49
+ Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
50
+ indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
51
+ :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
52
+ operates, because the default options can be changed using the
53
+ ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
54
+
55
+ Loss scaling is designed to combat the problem of underflowing gradients encountered at long
56
+ times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
57
+ scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
58
+ encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
59
+ occurred.
60
+ :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
61
+ and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
62
+ If a certain number of iterations occur without overflowing gradients detected,
63
+ :class:`DynamicLossScaler` increases the loss scale once more.
64
+ In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
65
+ always using the highest loss scale possible without incurring overflow.
66
+
67
+ Args:
68
+ init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
69
+ scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
70
+ scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
71
+ """
72
+
73
+ def __init__(self,
74
+ init_scale=2**32,
75
+ scale_factor=2.,
76
+ scale_window=1000):
77
+ self.cur_scale = init_scale
78
+ self.cur_iter = 0
79
+ self.last_overflow_iter = -1
80
+ self.scale_factor = scale_factor
81
+ self.scale_window = scale_window
82
+
83
+ # `params` is a list / generator of torch.Variable
84
+ def has_overflow(self, params):
85
+ for p in params:
86
+ if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
87
+ return True
88
+
89
+ return False
90
+
91
+ # `x` is a torch.Tensor
92
+ def _has_inf_or_nan(x):
93
+ try:
94
+ # if x is half, the .float() incurs an additional deep copy, but it's necessary if
95
+ # Pytorch's .sum() creates a one-element tensor of the same type as x
96
+ # (which is true for some recent version of pytorch).
97
+ cpu_sum = float(x.float().sum())
98
+ # More efficient version that can be used if .sum() returns a Python scalar
99
+ # cpu_sum = float(x.sum())
100
+ except RuntimeError as instance:
101
+ # We want to check if inst is actually an overflow exception.
102
+ # RuntimeError could come from a different error.
103
+ # If so, we still want the exception to propagate.
104
+ if "value cannot be converted" not in instance.args[0]:
105
+ raise
106
+ return True
107
+ else:
108
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
109
+ return True
110
+ return False
111
+
112
+ # `overflow` is boolean indicating whether the gradient overflowed
113
+ def update_scale(self, overflow):
114
+ if overflow:
115
+ # self.cur_scale /= self.scale_factor
116
+ self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
117
+ self.last_overflow_iter = self.cur_iter
118
+ else:
119
+ if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
120
+ self.cur_scale *= self.scale_factor
121
+ self.cur_iter += 1
122
+
123
+ @property
124
+ def loss_scale(self):
125
+ return self.cur_scale
126
+
127
+ def scale_gradient(self, module, grad_in, grad_out):
128
+ return tuple(self.loss_scale * g for g in grad_in)
129
+
130
+ def backward(self, loss, retain_graph=False):
131
+ scaled_loss = loss*self.loss_scale
132
+ scaled_loss.backward(retain_graph=retain_graph)
133
+
134
+ ##############################################################
135
+ # Example usage below here -- assuming it's in a separate file
136
+ ##############################################################
137
+ """
138
+ TO-DO separate out into an example.
139
+ if __name__ == "__main__":
140
+ import torch
141
+ from torch.autograd import Variable
142
+ from dynamic_loss_scaler import DynamicLossScaler
143
+
144
+ # N is batch size; D_in is input dimension;
145
+ # H is hidden dimension; D_out is output dimension.
146
+ N, D_in, H, D_out = 64, 1000, 100, 10
147
+
148
+ # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
149
+ x = Variable(torch.randn(N, D_in), requires_grad=False)
150
+ y = Variable(torch.randn(N, D_out), requires_grad=False)
151
+
152
+ w1 = Variable(torch.randn(D_in, H), requires_grad=True)
153
+ w2 = Variable(torch.randn(H, D_out), requires_grad=True)
154
+ parameters = [w1, w2]
155
+
156
+ learning_rate = 1e-6
157
+ optimizer = torch.optim.SGD(parameters, lr=learning_rate)
158
+ loss_scaler = DynamicLossScaler()
159
+
160
+ for t in range(500):
161
+ y_pred = x.mm(w1).clamp(min=0).mm(w2)
162
+ loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
163
+ print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
164
+ print('Iter {} scaled loss: {}'.format(t, loss.data[0]))
165
+ print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale))
166
+
167
+ # Run backprop
168
+ optimizer.zero_grad()
169
+ loss.backward()
170
+
171
+ # Check for overflow
172
+ has_overflow = DynamicLossScaler.has_overflow(parameters)
173
+
174
+ # If no overflow, unscale grad and update as usual
175
+ if not has_overflow:
176
+ for param in parameters:
177
+ param.grad.data.mul_(1. / loss_scaler.loss_scale)
178
+ optimizer.step()
179
+ # Otherwise, don't do anything -- ie, skip iteration
180
+ else:
181
+ print('OVERFLOW!')
182
+
183
+ # Update loss scale for next iteration
184
+ loss_scaler.update_scale(has_overflow)
185
+
186
+ """
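
As a quick numerical illustration of the ``update_scale`` schedule above with the default arguments (initial scale 2**32, factor 2, window 1000)::

    from apex.fp16_utils import DynamicLossScaler

    scaler = DynamicLossScaler()
    print(scaler.loss_scale)          # 2**32
    scaler.update_scale(True)         # one overflow halves the scale
    print(scaler.loss_scale)          # 2**31
    for _ in range(1000):             # a full scale_window of clean iterations...
        scaler.update_scale(False)
    print(scaler.loss_scale)          # ...doubles it back to 2**32
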
apex/apex/multi_tensor_apply/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .multi_tensor_apply import MultiTensorApply
2
+
3
+ multi_tensor_applier = MultiTensorApply(2048*32)
4
+
apex/apex/multi_tensor_apply/multi_tensor_apply.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+
3
+ class MultiTensorApply(object):
4
+ available = False
5
+ warned = False
6
+
7
+ def __init__(self, chunk_size):
8
+ try:
9
+ import amp_C
10
+ MultiTensorApply.available = True
11
+ self.chunk_size = chunk_size
12
+ except ImportError as err:
13
+ MultiTensorApply.available = False
14
+ MultiTensorApply.import_err = err
15
+
16
+ def check_avail(self):
17
+ if not MultiTensorApply.available:
18
+ raise RuntimeError(
19
+ "Attempted to call MultiTensorApply method, but MultiTensorApply "
20
+ "is not available, possibly because Apex was installed without "
21
+ "--cpp_ext --cuda_ext. Original import error message:",
22
+ MultiTensorApply.import_err)
23
+
24
+ def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
25
+ self.check_avail()
26
+
27
+ return op(self.chunk_size,
28
+ noop_flag_buffer,
29
+ tensor_lists,
30
+ *args)
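
The call pattern this wrapper expects is the one already used by ``FP16_Optimizer._master_params_to_model_params`` above, where the first tensor list is scaled by the final argument into the second list. A schematic example, assuming apex was built with ``--cpp_ext --cuda_ext`` so that ``amp_C`` is importable::

    import torch
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    overflow_buf = torch.cuda.IntTensor([0])
    src = [torch.randn(10, device='cuda'), torch.randn(20, device='cuda')]
    dst = [torch.empty_like(t) for t in src]
    # Mirrors the master-to-model copy above: dst tensors receive src * 1.0, applied in chunks.
    multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0)
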
apex/apex/normalization/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .fused_layer_norm import FusedLayerNorm
apex/apex/normalization/fused_layer_norm.py ADDED
@@ -0,0 +1,160 @@
1
+ import math
2
+ import torch
3
+ import numbers
4
+ from torch.nn.parameter import Parameter
5
+ from torch.nn import init
6
+ from torch.nn import functional as F
7
+ import importlib
8
+
9
+ class FusedLayerNormAffineFunction(torch.autograd.Function):
10
+ def __init__(self, normalized_shape, eps=1e-6):
11
+ global fused_layer_norm_cuda
12
+ fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
13
+
14
+ self.normalized_shape = normalized_shape
15
+ self.eps = eps
16
+
17
+ def forward(self, input, weight, bias):
18
+ input_ = input.contiguous()
19
+ weight_ = weight.contiguous()
20
+ bias_ = bias.contiguous()
21
+ output, mean, invvar = fused_layer_norm_cuda.forward_affine(
22
+ input_, self.normalized_shape, weight_, bias_, self.eps)
23
+ self.save_for_backward(input_, weight_, bias_, mean, invvar)
24
+ return output
25
+
26
+ def backward(self, grad_output):
27
+ input_, weight_, bias_, mean, invvar = self.saved_tensors
28
+ grad_input = grad_weight = grad_bias = None
29
+ grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
30
+ grad_output.contiguous(), mean, invvar,
31
+ input_, self.normalized_shape,
32
+ weight_, bias_, self.eps)
33
+ return grad_input, grad_weight, grad_bias
34
+
35
+ class FusedLayerNormFunction(torch.autograd.Function):
36
+ def __init__(self, normalized_shape, eps=1e-6):
37
+ global fused_layer_norm_cuda
38
+ fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
39
+ self.normalized_shape = normalized_shape
40
+ self.eps = eps
41
+
42
+ def forward(self, input):
43
+ input_ = input.contiguous()
44
+ output, mean, invvar = fused_layer_norm_cuda.forward(
45
+ input_, self.normalized_shape, self.eps)
46
+ self.save_for_backward(input_, mean, invvar)
47
+ return output
48
+
49
+ def backward(self, grad_output):
50
+ input_, mean, invvar = self.saved_tensors
51
+ grad_input = None
52
+ grad_input = fused_layer_norm_cuda.backward(
53
+ grad_output.contiguous(), mean, invvar,
54
+ input_, self.normalized_shape,
55
+ self.eps)
56
+ return grad_input
57
+
58
+ def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
59
+ return FusedLayerNormAffineFunction(normalized_shape,eps)(input, weight, bias)
60
+
61
+ def fused_layer_norm(input, normalized_shape, eps=1e-6):
62
+ return FusedLayerNormFunction(normalized_shape,eps)(input)
63
+
64
+ class FusedLayerNorm(torch.nn.Module):
65
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
66
+ the paper `Layer Normalization`_ .
67
+
68
+ Currently only runs on cuda() tensors.
69
+
70
+ .. math::
71
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
72
+
73
+ The mean and standard-deviation are calculated separately over the last
74
+ certain number dimensions which have to be of the shape specified by
75
+ :attr:`normalized_shape`.
76
+ :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
77
+ :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
78
+
79
+ .. note::
80
+ Unlike Batch Normalization and Instance Normalization, which applies
81
+ scalar scale and bias for each entire channel/plane with the
82
+ :attr:`affine` option, Layer Normalization applies per-element scale and
83
+ bias with :attr:`elementwise_affine`.
84
+
85
+ This layer uses statistics computed from input data in both training and
86
+ evaluation modes.
87
+
88
+ Args:
89
+ normalized_shape (int or list or torch.Size): input shape from an expected input
90
+ of size
91
+
92
+ .. math::
93
+ [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
94
+ \times \ldots \times \text{normalized}\_\text{shape}[-1]]
95
+
96
+ If a single integer is used, it is treated as a singleton list, and this module will
97
+ normalize over the last dimension which is expected to be of that specific size.
98
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
99
+ elementwise_affine: a boolean value that when set to ``True``, this module
100
+ has learnable per-element affine parameters initialized to ones (for weights)
101
+ and zeros (for biases). Default: ``True``.
102
+
103
+ Shape:
104
+ - Input: :math:`(N, *)`
105
+ - Output: :math:`(N, *)` (same shape as input)
106
+
107
+ Examples::
108
+
109
+ >>> input = torch.randn(20, 5, 10, 10)
110
+ >>> # With Learnable Parameters
111
+ >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
112
+ >>> # Without Learnable Parameters
113
+ >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
114
+ >>> # Normalize over last two dimensions
115
+ >>> m = apex.normalization.FusedLayerNorm([10, 10])
116
+ >>> # Normalize over last dimension of size 10
117
+ >>> m = apex.normalization.FusedLayerNorm(10)
118
+ >>> # Activating the module
119
+ >>> output = m(input)
120
+
121
+ .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
122
+ """
123
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
124
+ super(FusedLayerNorm, self).__init__()
125
+
126
+ global fused_layer_norm_cuda
127
+ fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
128
+
129
+ if isinstance(normalized_shape, numbers.Integral):
130
+ normalized_shape = (normalized_shape,)
131
+ self.normalized_shape = torch.Size(normalized_shape)
132
+ self.eps = eps
133
+ self.elementwise_affine = elementwise_affine
134
+ if self.elementwise_affine:
135
+ self.weight = Parameter(torch.Tensor(*normalized_shape))
136
+ self.bias = Parameter(torch.Tensor(*normalized_shape))
137
+ else:
138
+ self.register_parameter('weight', None)
139
+ self.register_parameter('bias', None)
140
+ self.reset_parameters()
141
+
142
+ def reset_parameters(self):
143
+ if self.elementwise_affine:
144
+ init.ones_(self.weight)
145
+ init.zeros_(self.bias)
146
+
147
+ def forward(self, input):
148
+ if not input.is_cuda:
149
+ return F.layer_norm(
150
+ input, self.normalized_shape, self.weight, self.bias, self.eps)
151
+ if self.elementwise_affine:
152
+ return FusedLayerNormAffineFunction(self.normalized_shape,self.eps)(
153
+ input, self.weight, self.bias)
154
+ else:
155
+ return FusedLayerNormFunction(self.normalized_shape,self.eps)(
156
+ input)
157
+
158
+ def extra_repr(self):
159
+ return '{normalized_shape}, eps={eps}, ' \
160
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
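
Because ``forward`` above falls back to ``F.layer_norm`` for CPU tensors and the learnable parameters match ``torch.nn.LayerNorm``'s names, a simple parity check is a reasonable smoke test; it assumes apex was built with the CUDA extension that provides ``fused_layer_norm_cuda``::

    import torch
    from apex.normalization import FusedLayerNorm

    x = torch.randn(20, 5, 10, 10, device='cuda')
    fused = FusedLayerNorm(x.size()[1:]).cuda()
    ref = torch.nn.LayerNorm(x.size()[1:]).cuda()
    ref.load_state_dict(fused.state_dict())        # share the same weight and bias
    print(torch.allclose(fused(x), ref(x), atol=1e-3))
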
apex/apex/optimizers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_adam import FusedAdam
2
+ from .fp16_optimizer import FP16_Optimizer
apex/apex/optimizers/fp16_optimizer.py ADDED
@@ -0,0 +1,274 @@
1
+ import torch
2
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
3
+
4
+ class FP16_Optimizer(object):
5
+ """
6
+ :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer,
7
+ designed only to wrap apex.optimizers.FusedAdam.
8
+ Refer to the apex.fp16_utils documentation for more information.
9
+
10
+ Example::
11
+
12
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
13
+ optimizer = apex.optimizers.FusedAdam(model.parameters())
14
+ # Name the FP16_Optimizer instance to replace the existing optimizer
15
+ # (recommended but not required):
16
+ optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
17
+ ...
18
+ # loss.backward() becomes:
19
+ optimizer.backward(loss)
20
+ ...
21
+
22
+ Example with dynamic loss scaling::
23
+
24
+ ...
25
+ optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
26
+ # optional arg to control dynamic loss scaling behavior
27
+ # dynamic_loss_args={'scale_window' : 500})
28
+ # Usually, dynamic_loss_args is not necessary.
29
+ """
30
+
31
+ def __init__(self,
32
+ init_optimizer,
33
+ static_loss_scale=1.0,
34
+ dynamic_loss_scale=False,
35
+ dynamic_loss_args=None,
36
+ verbose=True):
37
+
38
+ # The fused optimizer does all the work. We need this layer for two reasons:
39
+ # 1. maintain same user API from apex.fp16_utils
40
+ # 2. keep common stuff here in case we need to add new fused optimizer later
41
+
42
+ # differences from apex.fp16_utils:
43
+ # - assume all model params in fp16
44
+ # - assume all params requires grad
45
+ # - flat by groups, not keeping state. TODO: remove state explicitly?
46
+ # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
47
+ if not torch.cuda.is_available():
48
+ raise SystemError("Cannot use fp16 without CUDA.")
49
+ self.optimizer = init_optimizer
50
+
51
+ # param flattened by groups
52
+ self.fp16_groups = []
53
+ self.fp16_groups_flat = []
54
+ self.fp32_groups_flat = []
55
+
56
+ # loop to deal with groups
57
+ for i, param_group in enumerate(self.optimizer.param_groups):
58
+ # push this group to list before modify
59
+ self.fp16_groups.append(param_group['params'])
60
+ # init fp16 weight buffer, flattened
61
+ self.fp16_groups_flat.append(_flatten_dense_tensors([p.clone().detach() for p in self.fp16_groups[i]]))
62
+ # set model fp16 weight to slices of flattened buffer
63
+ updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
64
+ for p,q in zip(self.fp16_groups[i], updated_params):
65
+ p.data = q.data
66
+ # init master weight, flattened
67
+ self.fp32_groups_flat.append(self.fp16_groups_flat[i].clone().float().detach())
68
+ # modify optimizer to have flat master weight
69
+ self.fp32_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it
70
+ param_group['params'] = [self.fp32_groups_flat[i]]
71
+
72
+ # we may have a way of fusing dynamic scale. Not supported for now
73
+ if dynamic_loss_scale:
74
+ if dynamic_loss_args is not None:
75
+ raise SystemError("Do not support dynamic loss scale args for now.")
76
+ self.dynamic_loss_scale = True
77
+ self.cur_scale = 2**16
78
+ self.cur_iter = 0
79
+ self.last_overflow_iter = -1
80
+ self.scale_factor = 2
81
+ self.scale_window = 1000
82
+ else:
83
+ self.dynamic_loss_scale = False
84
+ self.cur_iter = 0
85
+ self.cur_scale = static_loss_scale
86
+ self.verbose = verbose
87
+
88
+ def zero_grad(self, set_grads_to_None=True):
89
+ """
90
+ Zero FP16 parameter grads.
91
+ """
92
+ # FP32 grad should never exist.
93
+ # For speed, set model fp16 grad to None by default
94
+ for group in self.fp16_groups:
95
+ for p in group:
96
+ if set_grads_to_None:
97
+ p.grad = None
98
+ else:
99
+ if p.grad is not None:
100
+ p.grad.detach_()
101
+ p.grad.zero_()
102
+
103
+ def _compute_grad_norm(self, fp16_grads_flat, norm_type=2):
104
+ """
105
+ Compute the fp16 grad norm for later clipping (fused with the update).
106
+ Internally accumulated in fp32.
107
+ The NaN check is also fused. Possibly other reductions needed for grad.
108
+
109
+ Args:
110
+ fp16_grads_flat (tensor): fp16 grad flattened
111
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
112
+ infinity norm.
113
+
114
+ Returns:
115
+ Total norm of the current fp16 gradients (viewed as a single vector).
116
+ Returns -1 if the most recently computed fp16 gradients overflowed
117
+ """
118
+ # TODO: Not most efficient with copy to cpu and sync
119
+ # only support 2-norm now
120
+ # for torch version <= 1.0.1, torch.norm with dtype will fail and fall back to cast
121
+ try:
122
+ norm = float(torch.norm(fp16_grads_flat, 2.0, dtype=torch.float32))
123
+ except TypeError as err:
124
+ norm = float(torch.norm(fp16_grads_flat.float(), 2.0))
125
+ if norm == float('inf') or norm == -float('inf') or norm != norm:
126
+ return -1
127
+ else:
128
+ return norm
129
+
130
+ def step(self, closure=None):
131
+ """
132
+ Not supporting closure.
133
+ """
134
+ # First compute norm for all group so we know if there is overflow
135
+ grads_groups_flat = []
136
+ norm_groups = []
137
+ skip = False
138
+ for i, group in enumerate(self.fp16_groups):
139
+ grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
140
+ norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
141
+ if norm_groups[i] == -1: #TODO: early break
142
+ skip = True
143
+
144
+ if skip:
145
+ self._update_scale(skip)
146
+ return
147
+
148
+ # norm is in fact norm*cur_scale
149
+ self.optimizer.step(grads=[[g] for g in grads_groups_flat],
150
+ output_params=[[p] for p in self.fp16_groups_flat],
151
+ scale=self.cur_scale,
152
+ grad_norms=norm_groups)
153
+
154
+ # TODO: we probably don't need this? just to be safe
155
+ for i in range(len(norm_groups)):
156
+ updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
157
+ for p,q in zip(self.fp16_groups[i], updated_params):
158
+ p.data = q.data
159
+
160
+ self._update_scale(False)
161
+ return
162
+
163
+ def backward(self, loss):
164
+ """
165
+ :attr:`backward` performs the following steps:
166
+
167
+ 1. fp32_loss = loss.float()
168
+ 2. scaled_loss = fp32_loss*loss_scale
169
+ 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
170
+ """
171
+ scaled_loss = (loss.float()) * self.cur_scale
172
+ scaled_loss.backward()
173
+
174
+ def _update_scale(self, skip):
175
+ if self.dynamic_loss_scale:
176
+ if skip:
177
+ if self.verbose:
178
+ print("\nGrad overflow on iteration", self.cur_iter)
179
+ print("Using dynamic loss scale of", self.cur_scale)
180
+ self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
181
+ self.last_overflow_iter = self.cur_iter
182
+ else:
183
+ if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
184
+ self.cur_scale *= self.scale_factor
185
+ else:
186
+ if skip:
187
+ print("\nGrad overflow on iteration", self.cur_iter)
188
+ print("Using static loss scale of", self.cur_scale)
189
+ self.cur_iter += 1
190
+ return
191
+
192
+ # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
193
+ def _get_state(self):
194
+ return self.optimizer.state
195
+
196
+ def _set_state(self, value):
197
+ self.optimizer.state = value
198
+
199
+ state = property(_get_state, _set_state)
200
+
201
+ # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
202
+ # (for example, to adjust the learning rate)
203
+ def _get_param_groups(self):
204
+ return self.optimizer.param_groups
205
+
206
+ def _set_param_groups(self, value):
207
+ self.optimizer.param_groups = value
208
+
209
+ param_groups = property(_get_param_groups, _set_param_groups)
210
+
211
+ def state_dict(self):
212
+ """
213
+ Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
214
+ This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
215
+ of the contained Pytorch optimizer.
216
+ Example::
217
+ checkpoint = {}
218
+ checkpoint['model'] = model.state_dict()
219
+ checkpoint['optimizer'] = optimizer.state_dict()
220
+ torch.save(checkpoint, "saved.pth")
221
+ """
222
+ state_dict = {}
223
+ state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
224
+ state_dict['cur_scale'] = self.cur_scale
225
+ state_dict['cur_iter'] = self.cur_iter
226
+ if state_dict['dynamic_loss_scale']:
227
+ state_dict['last_overflow_iter'] = self.last_overflow_iter
228
+ state_dict['scale_factor'] = self.scale_factor
229
+ state_dict['scale_window'] = self.scale_window
230
+ state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
231
+ state_dict['fp32_groups_flat'] = self.fp32_groups_flat
232
+ return state_dict
233
+
234
+ def load_state_dict(self, state_dict):
235
+ """
236
+ Loads a state_dict created by an earlier call to state_dict().
237
+ If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
238
+ whose parameters in turn came from ``model``, it is expected that the user
239
+ will call ``model.load_state_dict()`` before
240
+ ``fp16_optimizer_instance.load_state_dict()`` is called.
241
+ Example::
242
+ model = torch.nn.Linear(D_in, D_out).cuda().half()
243
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
244
+ optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
245
+ ...
246
+ checkpoint = torch.load("saved.pth")
247
+ model.load_state_dict(checkpoint['model'])
248
+ optimizer.load_state_dict(checkpoint['optimizer'])
249
+ """
250
+ # I think it should actually be ok to reload the optimizer before the model.
251
+ self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
252
+ self.cur_scale = state_dict['cur_scale']
253
+ self.cur_iter = state_dict['cur_iter']
254
+ if state_dict['dynamic_loss_scale']:
255
+ self.last_overflow_iter = state_dict['last_overflow_iter']
256
+ self.scale_factor = state_dict['scale_factor']
257
+ self.scale_window = state_dict['scale_window']
258
+ self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
259
+ # At this point, the optimizer's references to the model's fp32 parameters are up to date.
260
+ # The optimizer's hyperparameters and internal buffers are also up to date.
261
+ # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
262
+ # out of date. There are two options.
263
+ # 1: Refresh the master params from the model's fp16 params.
264
+ # This requires less storage but incurs precision loss.
265
+ # 2: Save and restore the fp32 master copies separately.
266
+ # We choose option 2.
267
+ #
268
+ # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
269
+ # of their associated parameters, because it's possible those buffers might not exist yet in
270
+ # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
271
+ # constructed in the same way as the one whose state_dict we are loading, the same master params
272
+ # are guaranteed to exist, so we can just copy_() from the saved master params.
273
+ for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
274
+ current.data.copy_(saved.data)
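To make the intended call pattern of this wrapper concrete, here is a minimal training-loop sketch along the lines of the docstring above. It assumes apex was built with `--cuda_ext --cpp_ext` so `apex.optimizers.FusedAdam` can load its extension; the dimensions, data, and loss are placeholders:

```
# Minimal sketch, assuming a CUDA build of apex; dimensions and data are placeholders.
import torch
import apex

model = torch.nn.Linear(512, 10).cuda().half()
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3)
optimizer = apex.optimizers.FP16_Optimizer(optimizer, dynamic_loss_scale=True)

for step in range(100):
    inp = torch.randn(32, 512, device="cuda", dtype=torch.half)
    target = torch.randint(0, 10, (32,), device="cuda")
    loss = torch.nn.functional.cross_entropy(model(inp).float(), target)
    optimizer.zero_grad()
    optimizer.backward(loss)   # replaces loss.backward(): scales the loss before backprop
    optimizer.step()           # silently skips the update if the scaled grads overflowed
```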
apex/apex/optimizers/fused_adam.py ADDED
@@ -0,0 +1,147 @@
1
+ import types
2
+ import torch
3
+ import importlib
4
+
5
+ class FusedAdam(torch.optim.Optimizer):
6
+
7
+ """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
8
+ ``python setup.py install --cuda_ext --cpp_ext``.
9
+
10
+ It has been proposed in `Adam: A Method for Stochastic Optimization`_.
11
+
12
+ Arguments:
13
+ params (iterable): iterable of parameters to optimize or dicts defining
14
+ parameter groups.
15
+ lr (float, optional): learning rate. (default: 1e-3)
16
+ betas (Tuple[float, float], optional): coefficients used for computing
17
+ running averages of gradient and its square. (default: (0.9, 0.999))
18
+ eps (float, optional): term added to the denominator to improve
19
+ numerical stability. (default: 1e-8)
20
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
21
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
22
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
23
+ (default: False) NOT SUPPORTED in FusedAdam!
24
+ eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
25
+ adds eps to the bias-corrected second moment estimate before
26
+ evaluating square root instead of adding it to the square root of
27
+ second moment estimate as in the original paper. (default: False)
28
+
29
+ .. _Adam\: A Method for Stochastic Optimization:
30
+ https://arxiv.org/abs/1412.6980
31
+ .. _On the Convergence of Adam and Beyond:
32
+ https://openreview.net/forum?id=ryQu7f-RZ
33
+ """
34
+
35
+ def __init__(self, params,
36
+ lr=1e-3, bias_correction = True,
37
+ betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
38
+ weight_decay=0., max_grad_norm=0., amsgrad=False):
39
+ global fused_adam_cuda
40
+ fused_adam_cuda = importlib.import_module("fused_adam_cuda")
41
+
42
+ if amsgrad:
43
+ raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
44
+ defaults = dict(lr=lr, bias_correction=bias_correction,
45
+ betas=betas, eps=eps, weight_decay=weight_decay,
46
+ max_grad_norm=max_grad_norm)
47
+ super(FusedAdam, self).__init__(params, defaults)
48
+ self.eps_mode = 0 if eps_inside_sqrt else 1
49
+
50
+ def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
51
+ """Performs a single optimization step.
52
+
53
+ Arguments:
54
+ closure (callable, optional): A closure that reevaluates the model
55
+ and returns the loss.
56
+ grads (list of tensors, optional): weight gradient to use for the
57
+ optimizer update. If gradients have type torch.half, parameters
58
+ are expected to be in type torch.float. (default: None)
59
+ output_params (list of tensors, optional): A reduced precision copy
60
+ of the updated weights written out in addition to the regular
61
+ updated weights. Must be of the same type as the gradients. (default: None)
62
+ scale (float, optional): factor to divide gradient tensor values
63
+ by before applying to weights. (default: 1)
64
+ """
65
+ loss = None
66
+ if closure is not None:
67
+ loss = closure()
68
+
69
+ if grads is None:
70
+ grads_group = [None]*len(self.param_groups)
71
+ # backward compatibility
72
+ # assuming a list/generator of parameter means single group
73
+ elif isinstance(grads, types.GeneratorType):
74
+ grads_group = [grads]
75
+ elif type(grads[0])!=list:
76
+ grads_group = [grads]
77
+ else:
78
+ grads_group = grads
79
+
80
+ if output_params is None:
81
+ output_params_group = [None]*len(self.param_groups)
82
+ elif isinstance(output_params, types.GeneratorType):
83
+ output_params_group = [output_params]
84
+ elif type(output_params[0])!=list:
85
+ output_params_group = [output_params]
86
+ else:
87
+ output_params_group = output_params
88
+
89
+ if grad_norms is None:
90
+ grad_norms = [None]*len(self.param_groups)
91
+
92
+ for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
93
+ if grads_this_group is None:
94
+ grads_this_group = [None]*len(group['params'])
95
+ if output_params_this_group is None:
96
+ output_params_this_group = [None]*len(group['params'])
97
+
98
+ # compute combined scale factor for this group
99
+ combined_scale = scale
100
+ if group['max_grad_norm'] > 0:
101
+ # norm is in fact norm*scale
102
+ clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
103
+ if clip > 1:
104
+ combined_scale = clip * scale
105
+
106
+ bias_correction = 1 if group['bias_correction'] else 0
107
+
108
+ for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
109
+ # note: p.grad should never be set for correct operation of the mixed precision optimizer, which sometimes sends None gradients
110
+ if p.grad is None and grad is None:
111
+ continue
112
+ if grad is None:
113
+ grad = p.grad.data
114
+ if grad.is_sparse:
115
+ raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
116
+
117
+ state = self.state[p]
118
+
119
+ # State initialization
120
+ if len(state) == 0:
121
+ state['step'] = 0
122
+ # Exponential moving average of gradient values
123
+ state['exp_avg'] = torch.zeros_like(p.data)
124
+ # Exponential moving average of squared gradient values
125
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
126
+
127
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
128
+ beta1, beta2 = group['betas']
129
+
130
+ state['step'] += 1
131
+
132
+ out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
133
+ fused_adam_cuda.adam(p.data,
134
+ out_p,
135
+ exp_avg,
136
+ exp_avg_sq,
137
+ grad,
138
+ group['lr'],
139
+ beta1,
140
+ beta2,
141
+ group['eps'],
142
+ combined_scale,
143
+ state['step'],
144
+ self.eps_mode,
145
+ bias_correction,
146
+ group['weight_decay'])
147
+ return loss
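Because `_compute_grad_norm` in the wrapper above is computed on already-scaled fp16 gradients, the `grad_norms` handed to `step` are really `true_norm * scale`. The following plain-Python sketch (no apex needed, values made up) re-derives the `combined_scale` used for the fused clipping:

```
# Plain-Python illustration of the combined_scale logic above; values are made up.
loss_scale = 128.0          # passed to step() as `scale`
scaled_grad_norm = 640.0    # norm of the scaled fp16 grads, i.e. true_norm * loss_scale
max_grad_norm = 1.0

true_norm = scaled_grad_norm / loss_scale       # 5.0
clip = (true_norm + 1e-6) / max_grad_norm       # ~5.0 > 1, so clipping kicks in
combined_scale = clip * loss_scale if clip > 1 else loss_scale

# The fused kernel divides gradients by combined_scale, which simultaneously
# undoes the loss scaling and caps the effective gradient norm at max_grad_norm.
print(combined_scale)   # ~640.0
```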
apex/apex/parallel/LARC.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.autograd import Variable
4
+ from torch.nn.parameter import Parameter
5
+
6
+ class LARC(object):
7
+ """
8
+ :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,
9
+ in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
10
+ local learning rate for each individual parameter. The algorithm is designed to improve
11
+ convergence of large batch training.
12
+
13
+ See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
14
+
15
+ In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
16
+ of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
17
+
18
+ ```
19
+ model = ...
20
+ optim = torch.optim.Adam(model.parameters(), lr=...)
21
+ optim = LARC(optim)
22
+ ```
23
+
24
+ It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
25
+
26
+ ```
27
+ model = ...
28
+ optim = torch.optim.Adam(model.parameters(), lr=...)
29
+ optim = LARC(optim)
30
+ optim = apex.fp16_utils.FP16_Optimizer(optim)
31
+ ```
32
+
33
+ Args:
34
+ optimizer: Pytorch optimizer to wrap and modify learning rate for.
35
+ trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888
36
+ clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`.
37
+ eps: epsilon kludge to help with numerical stability while calculating adaptive_lr
38
+ """
39
+
40
+ def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
41
+ self.param_groups = optimizer.param_groups
42
+ self.optim = optimizer
43
+ self.trust_coefficient = trust_coefficient
44
+ self.eps = eps
45
+ self.clip = clip
46
+
47
+ def __getstate__(self):
48
+ return self.optim.__getstate__()
49
+
50
+ def __setstate__(self, state):
51
+ self.optim.__setstate__(state)
52
+
53
+ def __repr__(self):
54
+ return self.optim.__repr__()
55
+
56
+ def state_dict(self):
57
+ return self.optim.state_dict()
58
+
59
+ def load_state_dict(self, state_dict):
60
+ self.optim.load_state_dict(state_dict)
61
+
62
+ def zero_grad(self):
63
+ self.optim.zero_grad()
64
+
65
+ def add_param_group(self, param_group):
66
+ self.optim.add_param_group(param_group)
67
+
68
+ def step(self):
69
+ with torch.no_grad():
70
+ weight_decays = []
71
+ for group in self.optim.param_groups:
72
+ # absorb weight decay control from optimizer
73
+ weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
74
+ weight_decays.append(weight_decay)
75
+ group['weight_decay'] = 0
76
+ for p in group['params']:
77
+ if p.grad is None:
78
+ continue
79
+ param_norm = torch.norm(p.data)
80
+ grad_norm = torch.norm(p.grad.data)
81
+
82
+ if param_norm != 0 and grad_norm != 0:
83
+ # calculate adaptive lr + weight decay
84
+ adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)
85
+
86
+ # clip learning rate for LARC
87
+ if self.clip:
88
+ # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
89
+ adaptive_lr = min(adaptive_lr/group['lr'], 1)
90
+
91
+ p.grad.data += weight_decay * p.data
92
+ p.grad.data *= adaptive_lr
93
+
94
+ self.optim.step()
95
+ # return weight decay control to optimizer
96
+ for i, group in enumerate(self.optim.param_groups):
97
+ group['weight_decay'] = weight_decays[i]
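A minimal sketch of the wrapping pattern described in the LARC docstring above, assuming a CUDA device; the model, learning rate, and data are placeholders:

```
# Minimal sketch of wrapping an optimizer with LARC; model and data are placeholders.
import torch
from apex.parallel.LARC import LARC

model = torch.nn.Linear(64, 10).cuda()
base_optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
optim = LARC(base_optim, trust_coefficient=0.02, clip=True)

inp = torch.randn(32, 64, device="cuda")
target = torch.randint(0, 10, (32,), device="cuda")
loss = torch.nn.functional.cross_entropy(model(inp), target)
optim.zero_grad()
loss.backward()
optim.step()   # rescales each parameter's gradient by its adaptive local lr, then runs SGD
```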
apex/apex/parallel/README.md ADDED
@@ -0,0 +1,66 @@
1
+ ## Distributed Data Parallel
2
+
3
+ distributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library.
4
+
5
+ `apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with
6
+ computation in the backward pass and bucketing smaller transfers to reduce the total number of
7
+ transfers required.
8
+
9
+ multiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs.
10
+
11
+ #### [API Documentation](https://nvidia.github.io/apex/parallel.html)
12
+
13
+ #### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed)
14
+
15
+ #### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
16
+
17
+ #### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex)
18
+
19
+ ### Synchronized Batch Normalization
20
+
21
+ `apex.parallel.SyncBatchNorm` has a similar API to `torch.nn.BatchNorm*N*d`.
22
+ It reduces stats on the first (channel) dimension of the Tensor and accepts
23
+ arbitrary spatial dimensions.
24
+
25
+ #### Installation
26
+
27
+ Apex provides two sync BN implementations:
28
+
29
+ 1. There is the Python-only implementation, which is the default implementation
30
+ when installing with `python setup.py install`.
31
+ It uses PyTorch primitive operations and the distributed communication package from
32
+ `torch.distributed`.
33
+
34
+ - _The Python-only implementation requires the input tensor to be of the same data type as the
35
+ layer_
36
+
37
+ 2. We also provide an implementation with kernels through a CUDA/C++ extension for
38
+ improved performance. We are experimenting with Welford and Kahan summation for the reduction,
39
+ hoping to get better accuracy.
40
+ To use the kernel implementation, users need to install Apex with the CUDA extension
41
+ enabled: `python setup.py install --cuda_ext`.
42
+
43
+ - _The custom kernel implementation supports fp16 input with an fp32 layer, as cudnn does.
44
+ This is required to run the imagenet example in fp16._
45
+
46
+ - _Currently the kernel implementation only supports GPU._
47
+
48
+ #### HowTo
49
+
50
+ 1. Users can use `apex.parallel.SyncBatchNorm` by building their modules with
51
+ the layer explicitly.
52
+
53
+ ```
54
+ import apex
55
+ input_t = torch.randn(3, 5, 20).cuda()
56
+ sbn = apex.parallel.SyncBatchNorm(5).cuda()
57
+ output_t = sbn(input)
58
+ ```
59
+
60
+ 2. Users can also take a constructed `torch.nn.Module` and replace all of its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through the utility function `apex.parallel.convert_syncbn_model`.
61
+
62
+ ```
63
+ # model is an instance of torch.nn.Module
64
+ import apex
65
+ sync_bn_model = apex.parallel.convert_syncbn_model(model)
66
+ ```