Mixed-precision arithmetic explained
Mixed-precision arithmetic is a form of floating-point arithmetic that uses numbers with varying widths in a single operation.
Overview
A common use of mixed-precision arithmetic is to operate on low-precision numbers with a small width and expand the result to a wider, more accurate representation. For example, two half-precision or bfloat16 (16-bit) floating-point numbers may be multiplied together to produce a more accurate single-precision (32-bit) float.[1] In this way, mixed-precision arithmetic approximates arbitrary-precision arithmetic, albeit with only a small number of possible precisions.
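The widening multiplication described above can be illustrated with a small sketch. It assumes FP16 and FP32 can be simulated by round-tripping Python floats through the `struct` module's `"e"` and `"f"` formats; the helper names `to_fp16` and `to_fp32` are hypothetical and exist only for this illustration.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# Two half-precision inputs (each has an 11-bit significand).
a = to_fp16(0.1)
b = to_fp16(0.3)

# The product of two 11-bit significands needs at most 22 bits,
# so Python's double-precision multiply gives the exact product here.
exact = a * b

prod16 = to_fp16(exact)  # result squeezed back into 16 bits
prod32 = to_fp32(exact)  # result widened to 32 bits (24-bit significand)

err16 = abs(prod16 - exact)
err32 = abs(prod32 - exact)
print(f"FP16 result error: {err16:.3e}")
print(f"FP32 result error: {err32:.3e}")
```

For these inputs the 32-bit result holds the product without rounding, while the 16-bit result does not.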
Iterative algorithms (like gradient descent) are good candidates for mixed-precision arithmetic. In an iterative algorithm such as computing a square root, a coarse integer guess can be made and refined over many iterations until the precision of the current format becomes the limit, that is, until even the smallest possible increment or decrement of the guess is too coarse to give an acceptable answer. When this happens, the computation can switch to a higher-precision format, which allows smaller increments to be used for the approximation.
Supercomputers such as Summit utilize mixed-precision arithmetic to be more efficient with regard to memory and processing time, as well as power consumption.[1] [2] [3]
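As a rough sketch of this idea, the following code refines a square root with Newton's iteration in simulated half precision until the iterate stops changing, then continues in simulated single precision. Newton's method is one possible choice of iteration, not something the text above prescribes; the function names are hypothetical, and rounding is simulated with the `struct` module rather than real 16-bit hardware arithmetic.

```python
import math
import struct

def round_fp16(x: float) -> float:
    """Round x to the nearest half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def round_fp32(x: float) -> float:
    """Round x to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def newton_sqrt(a, guess, round_fn, max_iter=50):
    """Iterate x <- (x + a/x) / 2, rounding each iterate with round_fn.

    Stops when the rounded iterate no longer changes, i.e. when the
    working precision is too coarse to improve the guess further.
    """
    x = round_fn(guess)
    for _ in range(max_iter):
        nxt = round_fn(0.5 * (x + a / x))
        if nxt == x:
            break
        x = nxt
    return x

# Stage 1: coarse refinement in half precision, from an integer guess.
coarse = newton_sqrt(2.0, 1.0, round_fp16)
# Stage 2: continue from the coarse answer in single precision.
refined = newton_sqrt(2.0, coarse, round_fp32)

print("FP16 stage error:", abs(coarse - math.sqrt(2.0)))
print("FP32 stage error:", abs(refined - math.sqrt(2.0)))
```

The second stage starts from an already good guess, so it needs only a few iterations in the more expensive format.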
Floating point format
A floating-point number is typically packed into a single bit-string, as the sign bit, the exponent field, and the significand or mantissa, from left to right. As an example, an IEEE 754 standard 32-bit float ("FP32", "float32", or "binary32") is packed as a 1-bit sign, an 8-bit exponent field, and a 23-bit significand field.
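The FP32 layout can be inspected directly. The sketch below splits a 32-bit float into its three fields; the function name `fp32_fields` is hypothetical, and the field widths (1, 8, and 23 bits) and the exponent bias of 127 are those of binary32.

```python
import struct

def fp32_fields(x: float):
    """Return (sign, biased exponent, fraction) of x stored as binary32."""
    # Reinterpret the 4 bytes of the float as an unsigned 32-bit integer.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    sign = bits >> 31                # leftmost bit
    exponent = (bits >> 23) & 0xFF   # next 8 bits, stored with bias 127
    fraction = bits & 0x7FFFFF       # rightmost 23 bits
    return sign, exponent, fraction

# -6.5 = -1.625 * 2**2, so the stored exponent is 2 + 127 = 129
# and the fraction field is 0.625 * 2**23 = 5242880.
sign, exponent, fraction = fp32_fields(-6.5)
print(sign, exponent, fraction)

# Rebuild the value of a normal number from its fields.
value = (-1) ** sign * (1 + fraction / 2**23) * 2.0 ** (exponent - 127)
print(value)
```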
The IEEE 754 binary floats are:
Type | Sign bits | Exponent bits | Significand bits | Total bits | Exponent bias | Bits of precision | Number of decimal digits |
---|---|---|---|---|---|---|---|
Half (IEEE 754-2008) | 1 | 5 | 10 | 16 | 15 | 11 | ~3.3 |
Single | 1 | 8 | 23 | 32 | 127 | 24 | ~7.2 |
Double | 1 | 11 | 52 | 64 | 1023 | 53 | ~15.9 |
x86 extended precision | 1 | 15 | 64 | 80 | 16383 | 64 | ~19.2 |
Quad | 1 | 15 | 112 | 128 | 16383 | 113 | ~34.0 |
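The decimal-digits column appears to follow from the bits of precision: p binary digits correspond to roughly p × log10(2) decimal digits. The short check below recomputes that column from the table's precision values; the helper name `decimal_digits` is hypothetical.

```python
import math

# Bits of precision, copied from the table above.
precision_bits = {
    "Half": 11,
    "Single": 24,
    "Double": 53,
    "x86 extended precision": 64,
    "Quad": 113,
}

def decimal_digits(p: int) -> float:
    """Approximate number of decimal digits carried by p binary digits."""
    return p * math.log10(2)

for name, p in precision_bits.items():
    print(f"{name}: {p} bits -> about {decimal_digits(p):.2f} decimal digits")
```

The table's figures seem to be these values cut to one decimal place rather than rounded.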
Machine learning
Mixed-precision arithmetic is used in the field of machine learning, since gradient descent algorithms can use coarse and efficient half-precision floats for certain tasks, but can be more accurate if they use more precise but slower single-precision floats. Some platforms, including Nvidia, Intel, and AMD CPUs and GPUs, provide mixed-precision arithmetic for this purpose, using coarse floats when possible, but expanding them to higher precision when necessary.[4] [5]
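One situation where higher precision becomes necessary is the accumulation of many small updates. The sketch below adds the same small value ten thousand times, once into a simulated FP16 accumulator and once into a simulated FP32 accumulator. The update size and step count are arbitrary illustrative choices, and the helpers `to_fp16` and `to_fp32` are hypothetical names that simulate rounding with the `struct` module.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to the nearest single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

update = to_fp16(0.001)  # a small update that is representable in FP16

acc16 = 0.0  # accumulator kept in half precision
acc32 = 0.0  # accumulator kept in single precision
for _ in range(10_000):
    # Each sum is rounded back to the accumulator's own format.
    acc16 = to_fp16(acc16 + update)
    acc32 = to_fp32(acc32 + update)

print("FP16 accumulator:", acc16)
print("FP32 accumulator:", acc32)
```

Once the FP16 accumulator grows large enough that the update is smaller than half the spacing between neighbouring FP16 values, further additions round away to nothing, while the FP32 accumulator keeps absorbing them.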
Automatic mixed precision
PyTorch implements automatic mixed-precision (AMP), which performs autocasting, gradient scaling, and loss scaling.[6] [7]
- The weights are stored in a master copy at a high precision, usually in FP32.
- Autocasting means automatically converting a floating-point number between different precisions, such as from FP32 to FP16, during training. For example, matrix multiplications can often be performed in FP16 without loss of accuracy, even if the master copy weights are stored in FP32. Low-precision weights are used during the forward pass.
- Gradient scaling means multiplying gradients by a constant factor during training, typically before the weight optimizer update. This is done to prevent the gradients from underflowing to zero when using low-precision data types like FP16. Mathematically, if the unscaled gradient is $g$, the scaled gradient is $kg$, where $k$ is the scaling factor. Within the optimizer update, the scaled gradient is cast to a higher precision before it is scaled down (no longer underflowing, as it is in a higher precision) to update the weights.
- Loss scaling means multiplying the loss function by a constant factor during training, typically before backpropagation. This is done to prevent the gradients from underflowing to zero when using low-precision data types. If the unscaled loss is $L$, the scaled loss is $kL$, where $k$ is the scaling factor. Since gradient scaling and loss scaling are mathematically equivalent by $\frac{\partial (kL)}{\partial w} = k \frac{\partial L}{\partial w}$, loss scaling is an implementation of gradient scaling.
PyTorch AMP uses exponential backoff to automatically adjust the scale factor for loss scaling. That is, it periodically increases the scale factor. Whenever the gradients contain a NaN (indicating overflow), the weight update is skipped, and the scale factor is decreased.
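The scaling and backoff behaviour described above can be sketched in plain Python. This is a toy model under stated assumptions, not PyTorch's implementation: the class `ToyLossScaler`, the function `train_step`, and all parameter values are hypothetical, the gradient is a single number, and FP16 is simulated with the `struct` module, with overflow mapped to infinity.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to half precision; values too large for FP16 become inf."""
    try:
        return struct.unpack("e", struct.pack("e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class ToyLossScaler:
    """Hypothetical dynamic scale factor with growth and backoff."""
    def __init__(self, scale=2.0**16, growth=2.0, backoff=0.5, growth_interval=4):
        self.scale = scale
        self.growth = growth
        self.backoff = backoff
        self.growth_interval = growth_interval
        self.good_steps = 0

    def update(self, grads_finite: bool) -> None:
        if not grads_finite:
            self.scale *= self.backoff      # overflow: shrink the scale
            self.good_steps = 0
            return
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            self.scale *= self.growth       # periodic growth after a clean run
            self.good_steps = 0

def train_step(weight, true_grad, scaler, lr=1.0):
    """One update; the scaled gradient passes through simulated FP16."""
    scaled_grad = to_fp16(true_grad * scaler.scale)
    finite = math.isfinite(scaled_grad)
    if finite:
        # Unscale in higher precision (a Python float) before updating.
        weight -= lr * (scaled_grad / scaler.scale)
    scaler.update(finite)                   # non-finite: update was skipped
    return weight

# A tiny gradient underflows in FP16 unless it is scaled first.
tiny = 1e-8
unscaled_fp16 = to_fp16(tiny)
recovered = to_fp16(tiny * 2.0**16) / 2.0**16

# A large gradient overflows until the scale has backed off far enough.
scaler = ToyLossScaler()
weight = 0.0
for _ in range(5):
    weight = train_step(weight, 10.0, scaler)
print(unscaled_fp16, recovered, scaler.scale, weight)
```

In this run the first four steps overflow and are skipped while the scale is halved each time; the fifth step is applied once the scaled gradient fits in FP16.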
Notes and References
- "Difference Between Single-, Double-, Multi-, Mixed-Precision". NVIDIA Blog. 15 November 2019. Retrieved 30 December 2020.
- Abdelfattah, Ahmad; Anzt, Hartwig; Boman, Erik G.; Carson, Erin; Cojean, Terry; Dongarra, Jack; Gates, Mark; Grützmacher, Thomas; Higham, Nicholas J.; Li, Sherry; Lindquist, Neil; Liu, Yang; Loe, Jennifer; Luszczek, Piotr; Nayak, Pratik; Pranesh, Sri; Rajamanickam, Siva; Ribizel, Tobias; Smith, Barry; Swirydowicz, Kasia; Thomas, Stephen; Tomov, Stanimire; Tsai, Yaohung M.; Yamazaki, Ichitaro; Yang, Ulrike Meier (2020). "A Survey of Numerical Methods Utilizing Mixed Precision Arithmetic". arXiv:2007.06674 [cs.MS].
- Holt, Kris (8 June 2018). "The US again has the world's most powerful supercomputer". Engadget. Retrieved 20 July 2018.
- Micikevicius, Paulius; Narang, Sharan; Alben, Jonah; Diamos, Gregory; Elsen, Erich; Garcia, David; Ginsburg, Boris; Houston, Michael; Kuchaiev, Oleksii (2018-02-15). "Mixed Precision Training". arXiv:1710.03740 [cs.AI].
- "Mixed-Precision Training of Deep Neural Networks". NVIDIA Technical Blog. 2017-10-11. Retrieved 2024-09-10.
- "Mixed Precision — PyTorch Training Performance Guide". residentmario.github.io. Retrieved 2024-09-10.
- "What Every User Should Know About Mixed Precision Training in PyTorch". PyTorch. Retrieved 2024-09-10.