stochastic_depth.py
3.7 KB · 98 lines · python Raw
1 # Implementation modified from torchvision:
2 # https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
3 #
4 # License:
5 # BSD 3-Clause License
6 #
7 # Copyright (c) Soumith Chintala 2016,
8 # All rights reserved.
9 #
10 # Redistribution and use in source and binary forms, with or without
11 # modification, are permitted provided that the following conditions are met:
12 #
13 # * Redistributions of source code must retain the above copyright notice, this
14 # list of conditions and the following disclaimer.
15 #
16 # * Redistributions in binary form must reproduce the above copyright notice,
17 # this list of conditions and the following disclaimer in the documentation
18 # and/or other materials provided with the distribution.
19 #
20 # * Neither the name of the copyright holder nor the names of its
21 # contributors may be used to endorse or promote products derived from
22 # this software without specific prior written permission.
23 #
24 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34
35 import torch
36 import torch.fx
37 from torch import nn, Tensor
38
39
def stochastic_depth(
    input: Tensor, p: float, mode: str, training: bool = True
) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions with the
            first one being its batch i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed. Must lie in ``[0, 1]``.
        mode (str): ``"batch"`` or ``"row"``.
            ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
            randomly selected rows from the batch.
        training: apply stochastic depth if is ``True``. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor. Entries that are kept are
        scaled by ``1 / (1 - p)``. When ``training`` is ``False`` or ``p`` is
        ``0``, ``input`` itself is returned unchanged.

    Raises:
        ValueError: if ``p`` is not within ``[0, 1]`` (NaN included), or if
            ``mode`` is neither ``"batch"`` nor ``"row"``.
    """
    # Expressed as a negated "in range" test on purpose: every comparison with
    # NaN is False, so ``p < 0.0 or p > 1.0`` would let a NaN probability slip
    # through validation (and be silently accepted when ``training`` is False).
    if not (p >= 0.0 and p <= 1.0):
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if not training or p == 0.0:
        # Nothing to drop: hand back the very same tensor, no copy.
        return input

    survival_rate = 1.0 - p
    if mode == "row":
        # One Bernoulli draw per batch row, broadcast over all trailing dims.
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        # A single draw shared by the whole tensor.
        size = [1] * input.ndim
    noise = torch.empty(size, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    # Rescale the kept entries by 1 / survival_rate. Skipped when p == 1:
    # the mask is all zeros then, and dividing by zero would yield NaNs.
    if survival_rate > 0.0:
        noise.div_(survival_rate)
    return input * noise
77
78
# Register ``stochastic_depth`` (by name) as a leaf function for torch.fx
# symbolic tracing, so a traced graph records a single call to it instead of
# tracing through its body.
torch.fx.wrap("stochastic_depth")
80
81
class StochasticDepth(nn.Module):
    """
    Module form of :func:`stochastic_depth`.

    Keeps the drop probability ``p`` and the ``mode`` given at construction
    and forwards them, together with the module's ``training`` flag, to
    :func:`stochastic_depth` on every call.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        # Stored verbatim; stochastic_depth checks both values when called.
        self.mode = mode
        self.p = p

    def forward(self, input: Tensor) -> Tensor:
        drop_prob = self.p
        drop_mode = self.mode
        return stochastic_depth(input, drop_prob, drop_mode, self.training)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(p={self.p}, mode={self.mode})"
98