Coverage for binned_cdf / binned_logit_cdf.py: 96%

140 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-16 05:35 +0000

1import math 

2from typing import Literal 

3 

4import torch 

5from torch.distributions import Distribution, constraints 

6 

7_size = torch.Size() 

8 

9 

10class BinnedLogitCDF(Distribution): 

11 """A histogram-based probability distribution parameterized by a bins for the CDF. 

12 

13 Each bin contributes a step function to the CDF when active. 

14 The activation of each bin is determined by applying a sigmoid to the corresponding logit. 

15 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing. 

16 

17 Note: 

18 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but 

19 not through the inputs of the `prob` or `cfg` method. 

20 """ 

21 

22 def __init__( 

23 self, 

24 logits: torch.Tensor, 

25 bound_low: float = -1e3, 

26 bound_up: float = 1e3, 

27 log_spacing: bool = False, 

28 bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid", 

29 validate_args: bool | None = None, 

30 ) -> None: 

31 """Initializer. 

32 

33 Args: 

34 logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins) 

35 bound_low: Lower bound of the distribution support, needs to be finite. 

36 bound_up: Upper bound of the distribution support, needs to be finite. 

37 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used. 

38 bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid", 

39 each bin is independently activated, while with "softmax", the bins activations influence each other. 

40 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class. 

41 """ 

42 self.logits = logits 

43 self.bound_low = bound_low 

44 self.bound_up = bound_up 

45 self.bin_normalization_method = bin_normalization_method 

46 self.log_spacing = log_spacing 

47 

48 # Create bin structure (same for all batch dimensions). 

49 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins( 

50 num_bins=logits.shape[-1], 

51 bound_low=bound_low, 

52 bound_up=bound_up, 

53 log_spacing=log_spacing, 

54 device=logits.device, 

55 dtype=logits.dtype, 

56 ) 

57 

58 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args) 

59 

60 @classmethod 

61 def _create_bins( 

62 cls, 

63 num_bins: int, 

64 bound_low: float, 

65 bound_up: float, 

66 log_spacing: bool, 

67 device: torch.device, 

68 dtype: torch.dtype, 

69 log_min_positive_edge: float = 1e-6, 

70 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

71 """Create bin edges with symmetric log spacing around zero. 

72 

73 Args: 

74 num_bins: Number of bins to create. 

75 bound_low: Lower bound of the distribution support. 

76 bound_up: Upper bound of the distribution support. 

77 log_spacing: Whether to use logarithmic spacing. 

78 device: Device for the tensors. 

79 dtype: Data type for the tensors. 

80 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument 

81 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing. 

82 

83 Returns: 

84 Tuple of (bin_edges, bin_centers, bin_widths). 

85 

86 Layout: 

87 - 1 edge at 0 

88 - num_bins//2 - 1 edges from 0 to bound_up (log spaced) 

89 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored) 

90 - 2 boundary edges at ±bounds 

91 

92 Total: num_bins + 1 edges creating num_bins bins 

93 """ 

94 if log_spacing: 

95 if not math.isclose(-bound_low, bound_up): 

96 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up") 

97 if bound_up <= 0: 97 ↛ 98line 97 didn't jump to line 98 because the condition on line 97 was never true

98 raise ValueError("log_spacing requires bound_up > 0") 

99 if num_bins % 2 != 0: 

100 raise ValueError("log_spacing requires even number of bins") 

101 

102 half_bins = num_bins // 2 

103 

104 # Create positive side: 0, internal edges, bound_up. 

105 if half_bins == 1: 

106 # Special case where we only use the boundary edges. 

107 positive_edges = torch.tensor([bound_up]) 

108 else: 

109 # Create half_bins - 1 internal edges between 0 and bound_up. 

110 internal_positive = torch.logspace( 

111 start=math.log2(log_min_positive_edge), 

112 end=math.log2(bound_up), 

113 steps=half_bins, 

114 base=2, 

115 ) 

116 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])]) 

117 

118 # Mirror for the negative side (excluding 0). 

119 negative_edges = -positive_edges.flip(0) 

120 

121 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary]. 

122 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges]) 

123 

124 else: 

125 # Linear spacing. 

126 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1) 

127 

128 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5 

129 bin_widths = bin_edges[1:] - bin_edges[:-1] 

130 

131 # Move to specified device and dtype. 

132 bin_edges = bin_edges.to(device=device, dtype=dtype) 

133 bin_centers = bin_centers.to(device=device, dtype=dtype) 

134 bin_widths = bin_widths.to(device=device, dtype=dtype) 

135 

136 return bin_edges, bin_centers, bin_widths 

137 

138 @property 

139 def num_bins(self) -> int: 

140 """Number of bins making up the BinnedLogitCDF.""" 

141 return self.logits.shape[-1] 

142 

143 @property 

144 def num_edges(self) -> int: 

145 """Number of bins edges of the BinnedLogitCDF.""" 

146 return self.bin_edges.shape[0] 

147 

148 @property 

149 def bin_probs(self) -> torch.Tensor: 

150 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).""" 

151 if self.bin_normalization_method == "sigmoid": 

152 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) 

153 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True) 

154 else: 

155 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins) 

156 return bin_probs 

157 

158 @property 

159 def mean(self) -> torch.Tensor: 

160 """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,).""" 

161 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) 

162 return torch.sum(weighted_centers, dim=-1) 

163 

164 @property 

165 def variance(self) -> torch.Tensor: 

166 """Compute variance of the distribution, of shape (*batch_shape,).""" 

167 # E[X^2] = weighted squared bin centers. 

168 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins) 

169 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,) 

170 

171 # Var = E[X^2] - E[X]^2 

172 return second_moment - self.mean**2 

173 

174 @property 

175 def support(self) -> constraints.Constraint: 

176 """Support of this distribution. Needs to be limitited to keep the number of bins manageable.""" 

177 return constraints.interval(self.bound_low, self.bound_up) 

178 

179 @property 

180 def arg_constraints(self) -> dict[str, constraints.Constraint]: 

181 """Constraints that should be satisfied by each argument of this distribution. None for this class.""" 

182 return {} 

183 

184 def expand( 

185 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None 

186 ) -> "BinnedLogitCDF": 

187 """Expand distribution to new batch shape. This creates a new instance.""" 

188 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins)) 

189 return BinnedLogitCDF( 

190 logits=expanded_logits, 

191 bound_low=self.bound_low, 

192 bound_up=self.bound_up, 

193 log_spacing=self.log_spacing, 

194 validate_args=self._validate_args, 

195 ) 

196 

197 def log_prob(self, value: torch.Tensor) -> torch.Tensor: 

198 """Compute log probability density at given values. 

199 

200 Args: 

201 value: Values at which to compute the log PDF. 

202 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

203 

204 Returns: 

205 Log PDF values corresponding to the input values. 

206 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

207 """ 

208 return torch.log(self.prob(value) + 1e-8) # small epsilon for stability 

209 

210 def prob(self, value: torch.Tensor) -> torch.Tensor: 

211 """Compute probability density at given values. 

212 

213 Args: 

214 value: Values at which to compute the PDF. 

215 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

216 

217 Returns: 

218 PDF values corresponding to the input values. 

219 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

220 """ 

221 if self._validate_args: 221 ↛ 224line 221 didn't jump to line 224 because the condition on line 221 was always true

222 self._validate_sample(value) 

223 

224 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

225 

226 # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions). 

227 if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape): 

228 value = value.expand(self.batch_shape) 

229 

230 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the 

231 # index where value would be inserted to maintain sorted order. 

232 # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index. 

233 value = value.contiguous() 

234 bin_indices = torch.searchsorted(self.bin_edges, value) - 1 # shape: (*sample_shape, *batch_shape) 

235 

236 # Clamp to valid range [0, num_bins - 1] to handle edge cases: 

237 # - values below bound_low would give bin_idx = -1 

238 # - values at bound_up would give bin_idx = num_bins 

239 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1) 

240 

241 # Gather the bin widths and probabilities for the selected bins. 

242 # For bin_widths of shape (num_bins,) we can index directly. 

243 bin_widths_selected = self.bin_widths[bin_indices] # shape: (*sample_shape, *batch_shape) 

244 

245 # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension. 

246 # Add sample dimensions to bin_probs and expand to match bin_indices shape. 

247 num_sample_dims = len(bin_indices.shape) - len(self.batch_shape) 

248 bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape) 

249 bin_probs_for_gather = bin_probs_for_gather.expand( 

250 *bin_indices.shape, -1 

251 ) # shape: (*sample_shape, *batch_shape, num_bins) 

252 

253 # Gather the selected bin probabilities. 

254 bin_indices_for_gather = bin_indices.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1) 

255 bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather) 

256 bin_probs_selected = bin_probs_selected.squeeze(-1) 

257 

258 # Compute PDF = probability mass / bin width. 

259 return bin_probs_selected / bin_widths_selected 

260 

261 def cdf(self, value: torch.Tensor) -> torch.Tensor: 

262 """Compute cumulative distribution function at given values. 

263 

264 Args: 

265 value: Values at which to compute the CDF. 

266 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

267 

268 Returns: 

269 CDF values in [0, 1] corresponding to the input values. 

270 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

271 """ 

272 if self._validate_args: 272 ↛ 275line 272 didn't jump to line 275 because the condition on line 272 was always true

273 self._validate_sample(value) 

274 

275 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

276 

277 # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions). 

278 if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape): 

279 value = value.expand(self.batch_shape) 

280 

281 # Use binary search to find how many bin centers are <= value. 

282 # torch.searchsorted with right=True gives us the number of elements <= value. 

283 value = value.contiguous() 

284 num_bins_active = torch.searchsorted(self.bin_centers, value, right=True) 

285 

286 # Clamp to valid range [0, num_bins]. 

287 num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins) # shape: (*sample_shape, *batch_shape) 

288 

289 # Compute cumulative sum of bin probabilities. 

290 # Prepend 0 for the case where no bins are active. 

291 num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape) 

292 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins) 

293 cumsum_probs = torch.cat( 

294 [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs], 

295 dim=-1, 

296 ) # shape: (*batch_shape, num_bins + 1) 

297 

298 # Expand cumsum_probs to match sample dimensions and gather. 

299 cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape) 

300 cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1) 

301 num_bins_active_for_gather = num_bins_active.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1) 

302 cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather) 

303 cdf_values = cdf_values.squeeze(-1) 

304 

305 return cdf_values 

306 

307 def icdf(self, value: torch.Tensor) -> torch.Tensor: 

308 """Compute the inverse CDF, i.e., the quantile function, at the given values. 

309 

310 Args: 

311 value: Values in [0, 1] at which to compute the inverse CDF. 

312 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

313 

314 Returns: 

315 Quantiles in [bound_low, bound_up] corresponding to the input CDF values. 

316 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

317 """ 

318 if self._validate_args and not (value >= 0).all() and (value <= 1).all(): 318 ↛ 319line 318 didn't jump to line 319 because the condition on line 318 was never true

319 raise ValueError("icdf input must be in [0, 1]") 

320 

321 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

322 

323 # Compute CDF at bin edges. prepend zeros to the cumsum of probabilities as this is always the first edge. 

324 cdf_edges = torch.cat( 

325 [ 

326 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), 

327 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins) 

328 ], 

329 dim=-1, 

330 ) # shape: (*batch_shape, num_bins + 1) 

331 

332 # Determine number of sample dimensions (dimensions before batch_shape). 

333 num_sample_dims = len(value.shape) - len(self.batch_shape) 

334 

335 # Prepend singleton dimensions for sample_shape to cdf_edges. 

336 # cdf_edges: (*batch_shape, num_bins + 1) -> (*sample_shape, *batch_shape, num_bins + 1) 

337 cdf_edges = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape) 

338 

339 # Prepend singleton dimensions for both sample_shape and batch_shape. 

340 # bin_edges: (num_bins + 1,) -> (*sample_shape, *batch_shape, num_bins + 1) 

341 bin_edges_expanded = self.bin_edges.view( 

342 (1,) * (num_sample_dims + len(self.batch_shape)) + self.bin_edges.shape 

343 ) 

344 

345 # Add bin dimension to value for comparison. 

346 value_expanded = value.unsqueeze(-1) 

347 

348 # Find bins containing the value: left_cdf <= value < right_cdf. 

349 bin_mask = (cdf_edges[..., :-1] <= value_expanded) & (value_expanded < cdf_edges[..., 1:]) 

350 bin_mask = bin_mask.to(self.logits.dtype) 

351 

352 # Handle edge case where value ≈ 1.0 (use isclose with dtype-appropriate defaults). 

353 value_is_one = torch.isclose(value_expanded, torch.ones_like(value_expanded)) 

354 bin_mask[..., -1] = torch.max(bin_mask[..., -1], value_is_one[..., 0]) # last bin could be selected already 

355 

356 # Selected the correct bin edges using the mask. Summing is essentially selecting here. 

357 # Summing fast and differentiable. 

358 cfd_value_bin_starts = torch.sum(bin_mask * cdf_edges[..., :-1], dim=-1) 

359 cdf_value_bin_ends = torch.sum(bin_mask * cdf_edges[..., 1:], dim=-1) 

360 bin_left_edges = torch.sum(bin_mask * bin_edges_expanded[..., :-1], dim=-1) 

361 bin_right_edges = torch.sum(bin_mask * bin_edges_expanded[..., 1:], dim=-1) 

362 

363 # Avoid division by zero. 

364 bin_width = cdf_value_bin_ends - cfd_value_bin_starts 

365 safe_bin_width = torch.where(bin_width > 1e-8, bin_width, torch.ones_like(bin_width)) 

366 

367 # Linear interpolation within the bin. 

368 alpha = (value - cfd_value_bin_starts) / safe_bin_width 

369 quantiles = bin_left_edges + alpha * (bin_right_edges - bin_left_edges) 

370 

371 return quantiles 

372 

373 @torch.no_grad() 

374 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor: 

375 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF. 

376 

377 Args: 

378 sample_shape: Shape of the samples to draw. 

379 

380 Returns: 

381 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution. 

382 """ 

383 shape = torch.Size(sample_shape) + self.batch_shape 

384 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) 

385 return self.icdf(uniform_samples) 

386 

387 def entropy(self) -> torch.Tensor: 

388 r"""Compute differential entropy of the distribution. 

389 

390 Entropy H(X) = -\sum_{x \in \mathcal{X}} p(x) \log( p(x) ) 

391 

392 Note: 

393 Here, we are doing an approximation by treating each bin as a uniform distribution over its width. 

394 """ 

395 bin_probs = self.bin_probs 

396 

397 # Get the PDF values at bin centers. 

398 pdf_values = bin_probs / self.bin_widths # shape: (*batch_shape, num_bins) 

399 

400 # Entropy ≈ -∑ p_i * log(pdf_i) * bin_width_i. 

401 log_pdf = torch.log(pdf_values + 1e-8) # small epsilon for stability 

402 entropy_per_bin = -bin_probs * log_pdf 

403 

404 # Sum over bins to get total entropy. 

405 return torch.sum(entropy_per_bin, dim=-1) 

406 

407 def __repr__(self) -> str: 

408 """String representation of the distribution.""" 

409 return ( 

410 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, " 

411 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})" 

412 )