Coverage for binned_cdf / piecewise_linear_binned_cdf.py: 96%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-09 09:21 +0000

1import torch 

2 

3from .piecewise_constant_binned_cdf import PiecewiseConstantBinnedCDF 

4 

5 

class PiecewiseLinearBinnedCDF(PiecewiseConstantBinnedCDF):
    """A continuous probability distribution parameterized by binned logits for the CDF.

    Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF], which
    evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., a
    histogram PDF, interpolating linearly between bin edges.
    """

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution, of shape (*batch_shape,).

        Note:
            The density is piecewise constant (uniform within each bin), so by the law of total variance the result
            is the parent class's variance (presumably the variance of the bin centers under the bin probabilities)
            plus the mean intra-bin variance `sum_i p_i * w_i**2 / 12` (Sheppard's correction), where `p_i` and
            `w_i` are the probability mass and width of bin `i`.
        """
        discrete_var = super().variance
        intra_bin_var = torch.sum(self.bin_probs * (self.bin_widths**2) / 12.0, dim=-1)  # Sheppard's correction
        return discrete_var + intra_bin_var

    def _cdf_at_bin_edges(self) -> torch.Tensor:
        """Return the CDF evaluated at every bin edge, of shape (*batch_shape, num_bins + 1).

        Entry 0 is 0 (the left edge of the first bin) and entry `i` is the total mass of bins `0 .. i-1`.
        """
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        return torch.cat([zero_prefix, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        # Log of the probability mass of the bin each value falls into. The parent call also validates the args when
        # self._validate_args is True, so run it before any other processing of `value`.
        log_mass = super().log_prob(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Gather the width of the bin each value falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Log density = log(mass / width) = log_mass - log_width. The eps guard matches the one used in `prob`.
        eps = torch.finfo(widths.dtype).eps
        return log_mass - torch.log(widths + 2 * eps)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather normalized mass and bin width.
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Density = p(bin_i) / width_i; the small eps avoids division by zero for degenerate (zero-width) bins.
        eps = torch.finfo(widths.dtype).eps
        return masses / (widths + 2 * eps)

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Find the bin (in value space) each input falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # CDF at the left edge of each value's bin (0 for the first bin).
        base_cdf = self._gather_from_bins(
            self._cdf_at_bin_edges(), bin_indices, num_sample_dims, target_shape=value_prep.shape
        )

        # Interpolate: cdf = base_cdf + (x_input - x_left_edge) * (mass / width).
        eps = torch.finfo(widths.dtype).eps
        alpha = (value_prep - left_edges) / (widths + 2 * eps)
        alpha = torch.clamp(alpha, 0.0, 1.0)  # prevent extrapolation outside the bin
        return base_cdf + alpha * masses

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # CDF edges are the y-coordinates of the piecewise linear segments, shape (*batch_shape, num_bins + 1).
        cdf_edges = self._cdf_at_bin_edges()

        # Find the bin in probability space. The edges are broadcast over the sample dimensions so that every
        # target probability is searched against its own batch element's edges.
        cdf_edges_aligned = (
            cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*value_prep.shape, -1).contiguous()
        )
        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

        # Gather the probability base, i.e., the CDF at the left edge of the selected bin.
        base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Interpolate: x = x_left_edge + (target_cdf - base_cdf) * (width / mass).
        eps = torch.finfo(masses.dtype).eps
        slope = widths / (masses + 2 * eps)  # eps avoids division by zero for bins with no mass
        interp_value = left_edges + (value_prep - base_cdf) * slope

        return torch.clamp(interp_value, self.bound_low, self.bound_up)

    def entropy(self) -> torch.Tensor:
        r"""Compute the differential entropy of the distribution, of shape (*batch_shape,).

        The density is constant within each bin, $f(x) = p_i / w_i$ for $x$ in bin $i$, so the differential entropy
        $H(X) = -\int f(x) \log f(x) \, dx$ reduces exactly to

            H(X) = -\sum_i p_i \log(p_i / w_i),

        where $p_i$ and $w_i$ are the probability mass and width of bin $i$.

        Note:
            Before taking the log, the density is clamped from below at the smallest positive normal number of its
            dtype. Bins with zero mass therefore contribute 0 (instead of NaN) and their gradients stay finite, while
            every other bin is evaluated without any additive bias (unlike a fixed epsilon, whose effect would depend
            on the units of the support).
        """
        bin_probs = self.bin_probs

        # Density value inside each bin, shape: (*batch_shape, num_bins).
        pdf_values = bin_probs / self.bin_widths

        tiny = torch.finfo(pdf_values.dtype).tiny
        log_pdf = torch.log(torch.clamp(pdf_values, min=tiny))

        # H = -sum_i p_i * log(pdf_i); summing over the bin dimension yields shape (*batch_shape,).
        return -torch.sum(bin_probs * log_pdf, dim=-1)