stochastic_depth.py
| 1 | # Implementation modified from torchvision: |
| 2 | # https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py |
| 3 | # |
| 4 | # License: |
| 5 | # BSD 3-Clause License |
| 6 | # |
| 7 | # Copyright (c) Soumith Chintala 2016, |
| 8 | # All rights reserved. |
| 9 | # |
| 10 | # Redistribution and use in source and binary forms, with or without |
| 11 | # modification, are permitted provided that the following conditions are met: |
| 12 | # |
| 13 | # * Redistributions of source code must retain the above copyright notice, this |
| 14 | # list of conditions and the following disclaimer. |
| 15 | # |
| 16 | # * Redistributions in binary form must reproduce the above copyright notice, |
| 17 | # this list of conditions and the following disclaimer in the documentation |
| 18 | # and/or other materials provided with the distribution. |
| 19 | # |
| 20 | # * Neither the name of the copyright holder nor the names of its |
| 21 | # contributors may be used to endorse or promote products derived from |
| 22 | # this software without specific prior written permission. |
| 23 | # |
| 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 34 | |
| 35 | import torch |
| 36 | import torch.fx |
| 37 | from torch import nn, Tensor |
| 38 | |
| 39 | |
def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.

    Surviving entries are rescaled by ``1 / (1 - p)``; when ``p == 1.0`` the
    output is all zeros and no rescaling is applied.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if is ``True``. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor. The ``input`` object itself
        is returned unchanged when ``training`` is ``False`` or ``p == 0.0``.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]`` or ``mode`` is neither
            ``"batch"`` nor ``"row"``.
    """
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    # Arguments are validated before this shortcut, so bad values are reported
    # even in eval mode.
    if not training or p == 0.0:
        return input

    survival_rate = 1.0 - p
    # The mask broadcasts against ``input``: one entry per row in "row" mode,
    # a single entry shared by the whole tensor in "batch" mode.
    if mode == "row":
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        size = [1] * input.ndim
    noise = torch.empty(size, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    # Rescale kept entries by 1 / survival_rate; skipped when survival_rate is
    # 0 (p == 1.0) to avoid dividing by zero -- the mask is all zeros anyway.
    if survival_rate > 0.0:
        noise.div_(survival_rate)
    return input * noise
| 77 | |
| 78 | |
# Register ``stochastic_depth`` as an FX "leaf" function: symbolic tracing then
# records a single call to it instead of tracing through its Python-level
# branches on ``p``, ``mode`` and ``training``. ``torch.fx.wrap`` must be
# called at module level.
torch.fx.wrap("stochastic_depth")
| 80 | |
| 81 | |
class StochasticDepth(nn.Module):
    """
    Module wrapper around :func:`stochastic_depth`.

    See :func:`stochastic_depth`.

    Args:
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        # Stored as given; the values are validated by ``stochastic_depth``
        # each time ``forward`` runs.
        self.p = p
        self.mode = mode

    def forward(self, input: Tensor) -> Tensor:
        """Apply :func:`stochastic_depth` using the module's ``training`` flag."""
        return stochastic_depth(input, self.p, self.mode, self.training)

    def __repr__(self) -> str:
        """Return ``ClassName(p=..., mode=...)``."""
        s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
        return s
| 98 | |