Coverage for binned_cdf/piecewise_linear_binned_cdf.py: 96%

66 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-08 12:02 +0000

1import torch 

2 

3from .piecewise_constant_binned_cdf import PiecewiseConstantBinnedCDF 

4 

5 

class PiecewiseLinearBinnedCDF(PiecewiseConstantBinnedCDF):
    """A continuous probability distribution parameterized by binned logits for the CDF.

    Unlike [PiecewiseConstantBinnedCDF][binned_cdf.piecewise_constant_binned_cdf.PiecewiseConstantBinnedCDF], which
    evaluates the CDF as a step function over bin centers, this class implements a true piecewise-linear CDF, i.e., a
    histogram PDF, interpolating linearly between bin edges.
    """

    @property
    def variance(self) -> torch.Tensor:
        """Compute variance of the distribution.

        Note:
            Since the CDF is piecewise linear, the density is uniform within each bin. The variance therefore
            consists of the discrete variance computed by the parent class from the bin probabilities, plus the
            expected intra-bin variance sum_i p_i * w_i**2 / 12 (the term known from Sheppard's correction).
            This decomposition assumes that the parent's variance places each bin's mass at the bin center.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        discrete_var = super().variance
        # Variance of a uniform distribution over a bin of width w is w**2 / 12 (Sheppard's correction term).
        intra_bin_var = torch.sum(self.bin_probs * (self.bin_widths**2) / 12.0, dim=-1)
        return discrete_var + intra_bin_var

    def _cdf_at_bin_edges(self) -> torch.Tensor:
        """Return the CDF evaluated at every bin edge.

        Returns:
            Tensor of shape (*batch_shape, num_bins + 1). Entry i is the summed probability of all bins left of
            edge i, so the first entry is 0 and the last one is the total mass of all bins.
        """
        zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device)
        return torch.cat([zero_prefix, torch.cumsum(self.bin_probs, dim=-1)], dim=-1)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the log-probability density at given values.

        Args:
            value: Values at which to compute the log-PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Log-PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        # Log of the probability mass of the bin the value falls into. Called first so that the arguments are
        # validated (if self._validate_args is True) before any further processing of the input.
        log_mass = super().log_prob(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Gather the width of the bin each value falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Log density = log(mass / width) = log_mass - log_width; eps guards against zero-width bins.
        eps = torch.finfo(widths.dtype).eps
        return log_mass - torch.log(widths + 2 * eps)

    def prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute probability density at given values.

        Args:
            value: Values at which to compute the PDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            PDF values corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather normalized mass and bin width.
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Density = p(bin_i) / width_i; eps guards against zero-width bins.
        eps = torch.finfo(widths.dtype).eps
        return masses / (widths + 2 * eps)

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute cumulative distribution function at given values.

        Args:
            value: Values at which to compute the CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            CDF values in [0, 1] corresponding to the input values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # Find the bin (in value space) each value falls into.
        bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # CDF at the left edge of each value's bin (the leading 0 corresponds to the first bin's left edge).
        base_cdf = self._gather_from_bins(
            self._cdf_at_bin_edges(), bin_indices, num_sample_dims, target_shape=value_prep.shape
        )

        # Interpolate: cdf = base_cdf + (x_input - x_left_edge) / width * mass
        eps = torch.finfo(widths.dtype).eps
        alpha = (value_prep - left_edges) / (widths + 2 * eps)
        alpha = torch.clamp(alpha, 0.0, 1.0)  # prevent extrapolation outside the bin
        return base_cdf + alpha * masses

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """Compute the inverse CDF, i.e., the quantile function, at the given values.

        Args:
            value: Values in [0, 1] at which to compute the inverse CDF.
                Expected shape: (*sample_shape, *batch_shape) or broadcastable to it.

        Returns:
            Quantiles in [bound_low, bound_up] corresponding to the input CDF values.
            Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape).
        """
        if self._validate_args:
            self._validate_sample(value)

        value_prep, num_sample_dims = self._prepare_input(value)

        # CDF values at the bin edges (y-coordinates of the piecewise-linear segments).
        cdf_edges = self._cdf_at_bin_edges()

        # Find the bin in probability space: broadcast the CDF edges over the sample dimensions.
        cdf_edges_aligned = (
            cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape).expand(*value_prep.shape, -1).contiguous()
        )
        bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_aligned)

        # Gather the probability base.
        base_cdf = self._gather_from_bins(cdf_edges, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Gather the interpolation parameters.
        left_edges = self._gather_from_bins(
            self.bin_edges[:-1], bin_indices, num_sample_dims, target_shape=value_prep.shape
        )
        widths = self._gather_from_bins(self.bin_widths, bin_indices, num_sample_dims, target_shape=value_prep.shape)
        masses = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape)

        # Interpolate: x = x_left_edge + (target_cdf - base_cdf) * (width / mass)
        eps = torch.finfo(masses.dtype).eps
        slope = widths / (masses + 2 * eps)  # eps avoids division by zero for bins with no mass
        interp_value = left_edges + (value_prep - base_cdf) * slope

        return torch.clamp(interp_value, self.bound_low, self.bound_up)

    def entropy(self) -> torch.Tensor:
        r"""Compute differential entropy of the distribution.

        The density is constant within each bin, f(x) = p_i / w_i for x in bin i, hence

            H(X) = -\int f(x) \log f(x) dx = -\sum_i p_i \log(p_i / w_i).

        Note:
            The per-bin density is computed exactly as in `prob`, with a small epsilon added to the widths so that
            zero-width bins do not produce inf/NaN. Instead of adding a fixed offset inside the log, which biases the
            result when densities are small (e.g., for wide bins), the log argument is floored at the smallest
            positive normal float. Empty bins therefore contribute exactly zero and gradients stay finite.

        Returns:
            Tensor of shape (*batch_shape,).
        """
        bin_probs = self.bin_probs
        widths = self.bin_widths

        # Constant density of each bin, consistent with `prob`; shape: (*batch_shape, num_bins).
        eps = torch.finfo(widths.dtype).eps
        pdf_values = bin_probs / (widths + 2 * eps)

        # Floor (rather than offset) the log argument: exact for non-negligible densities, finite for empty bins.
        log_pdf = torch.log(pdf_values.clamp_min(torch.finfo(pdf_values.dtype).tiny))

        # H = -sum_i p_i * log(pdf_i), summed over bins.
        return -torch.sum(bin_probs * log_pdf, dim=-1)