Coverage for binned_cdf / piecewise_constant_binned_cdf.py: 94%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-09 09:21 +0000

1import math 

2from typing import Literal 

3 

4import torch 

5from torch.distributions import Distribution, constraints 

6from torch.nn.functional import log_softmax, logsigmoid 

7 

# Default empty sample shape used by `sample()` when no `sample_shape` is given.
_size = torch.Size()

9 

10 

11class PiecewiseConstantBinnedCDF(Distribution): 

12 """A discrete probability distribution parameterized by binned logits for the CDF. 

13 

14 Each bin contributes a step function to the CDF when active. 

15 The activation of each bin is determined by applying a sigmoid to the corresponding logit. 

16 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing. 

17 

18 Note: 

19 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but 

not through the inputs of the `prob` or `cdf` method.

21 """ 

22 

23 def __init__( 

24 self, 

25 logits: torch.Tensor, 

26 bound_low: float = -1e3, 

27 bound_up: float = 1e3, 

28 log_spacing: bool = False, 

29 bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid", 

30 validate_args: bool | None = None, 

31 ) -> None: 

32 """Initializer. 

33 

34 Args: 

35 logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins) 

36 bound_low: Lower bound of the distribution support, needs to be finite. 

37 bound_up: Upper bound of the distribution support, needs to be finite. 

38 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used. 

39 bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid", 

40 each bin is independently activated, while with "softmax", the bins activations influence each other. 

41 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class. 

42 """ 

43 self.logits = logits 

44 self.bound_low = bound_low 

45 self.bound_up = bound_up 

46 self.bin_normalization_method = bin_normalization_method 

47 self.log_spacing = log_spacing 

48 

49 # Create bin structure (same for all batch dimensions). 

50 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins( 

51 num_bins=logits.shape[-1], 

52 bound_low=bound_low, 

53 bound_up=bound_up, 

54 log_spacing=log_spacing, 

55 device=logits.device, 

56 dtype=logits.dtype, 

57 ) 

58 

59 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args) 

60 

61 @classmethod 

62 def _create_bins( 

63 cls, 

64 num_bins: int, 

65 bound_low: float, 

66 bound_up: float, 

67 log_spacing: bool, 

68 device: torch.device, 

69 dtype: torch.dtype, 

70 log_min_positive_edge: float = 1e-6, 

71 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

72 """Create bin edges with symmetric log spacing around zero. 

73 

74 Args: 

75 num_bins: Number of bins to create. 

76 bound_low: Lower bound of the distribution support. 

77 bound_up: Upper bound of the distribution support. 

78 log_spacing: Whether to use logarithmic spacing. 

79 device: Device for the tensors. 

80 dtype: Data type for the tensors. 

81 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument 

82 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing. 

83 

84 Returns: 

85 Tuple of (bin_edges, bin_centers, bin_widths). 

86 

87 Layout: 

88 - 1 edge at 0 

89 - num_bins//2 - 1 edges from 0 to bound_up (log spaced) 

90 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored) 

91 - 2 boundary edges at ±bounds 

92 

93 Total: num_bins + 1 edges creating num_bins bins 

94 """ 

95 if log_spacing: 

96 if not math.isclose(-bound_low, bound_up): 

97 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up") 

98 if bound_up <= 0: 98 ↛ 99line 98 didn't jump to line 99 because the condition on line 98 was never true

99 raise ValueError("log_spacing requires bound_up > 0") 

100 if num_bins % 2 != 0: 

101 raise ValueError("log_spacing requires even number of bins") 

102 

103 half_bins = num_bins // 2 

104 

105 # Create positive side: 0, internal edges, bound_up. 

106 if half_bins == 1: 

107 # Special case where we only use the boundary edges. 

108 positive_edges = torch.tensor([bound_up]) 

109 else: 

110 # Create half_bins - 1 internal edges between 0 and bound_up. 

111 internal_positive = torch.logspace( 

112 start=math.log2(log_min_positive_edge), 

113 end=math.log2(bound_up), 

114 steps=half_bins, 

115 base=2, 

116 ) 

117 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])]) 

118 

119 # Mirror for the negative side (excluding 0). 

120 negative_edges = -positive_edges.flip(0) 

121 

122 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary]. 

123 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges]) 

124 

125 else: 

126 # Linear spacing. 

127 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1) 

128 

129 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5 

130 bin_widths = bin_edges[1:] - bin_edges[:-1] 

131 

132 # Move to specified device and dtype. 

133 bin_edges = bin_edges.to(device=device, dtype=dtype) 

134 bin_centers = bin_centers.to(device=device, dtype=dtype) 

135 bin_widths = bin_widths.to(device=device, dtype=dtype) 

136 

137 return bin_edges, bin_centers, bin_widths 

138 

139 @property 

140 def num_bins(self) -> int: 

141 """Number of bins making up the PiecewiseConstantBinnedCDF.""" 

142 return self.logits.shape[-1] 

143 

144 @property 

145 def num_edges(self) -> int: 

146 """Number of bins edges of the PiecewiseConstantBinnedCDF.""" 

147 return self.bin_edges.shape[0] 

148 

149 @property 

150 def bin_probs(self) -> torch.Tensor: 

151 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).""" 

152 if self.bin_normalization_method == "sigmoid": 

153 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) 

154 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True) 

155 else: 

156 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins) 

157 return bin_probs 

158 

159 @property 

160 def mean(self) -> torch.Tensor: 

161 """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,).""" 

162 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) 

163 return torch.sum(weighted_centers, dim=-1) 

164 

165 @property 

166 def variance(self) -> torch.Tensor: 

167 """Compute variance of the distribution, of shape (*batch_shape,).""" 

168 # E[X^2] = weighted squared bin centers. 

169 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins) 

170 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,) 

171 

172 # Var = E[X^2] - E[X]^2 

173 return second_moment - self.mean**2 

174 

175 @property 

176 def support(self) -> constraints.Constraint: 

177 """Support of this distribution. Needs to be limitited to keep the number of bins manageable.""" 

178 return constraints.interval(self.bound_low, self.bound_up) 

179 

    @property
    def arg_constraints(self) -> dict[str, constraints.Constraint]:
        """Constraints that should be satisfied by each argument of this distribution.

        Empty for this class: the raw logits are unconstrained real numbers.
        """
        return {}

184 

185 def expand( 

186 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None 

187 ) -> "PiecewiseConstantBinnedCDF": 

188 """Expand distribution to new batch shape. This creates a new instance.""" 

189 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins)) 

190 return self.__class__( 

191 logits=expanded_logits, 

192 bound_low=self.bound_low, 

193 bound_up=self.bound_up, 

194 log_spacing=self.log_spacing, 

195 validate_args=self._validate_args, 

196 ) 

197 

198 def _prepare_input(self, value: torch.Tensor) -> tuple[torch.Tensor, int]: 

199 """Prepare the input tensor for `log_prob`, `prob`, `cdf` and `icdf` computations. 

200 

201 This method handles device/dtype transfer, batch dimension alignment, and broadcasting. 

202 

203 Args: 

204 value: Input tensor to prepare. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. 

205 For example, if `batch_shape` is `(B1, B2)` and `value` is `(S1, S2)`, it will be broadcast to 

206 `(S1, S2, B1, B2)`. If `value` is `(B1, B2)` (no sample dims), it remains `(B1, B2)`. 

207 

208 Returns: 

209 A tuple containing: 

210 - Prepared `value` tensor, of shape: `(*sample_shape, *batch_shape)`. 

211 - `num_sample_dims`: The number of sample dimensions in the prepared `value` tensor. 

212 """ 

213 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

214 

215 # This ensures the batch dimension is the last dimension. 

216 if len(self.batch_shape) > 0: # noqa: SIM102 

217 # Check if the rightmost dimensions of value match batch_shape. 

218 # If they don't, we assume value is missing the batch dimensions. 

219 if value.shape[-len(self.batch_shape) :] != self.batch_shape: 

220 value = value.unsqueeze(-1) 

221 

222 num_sample_dims = max(0, value.ndim - len(self.batch_shape)) 

223 target_shape = torch.Size(value.shape[:num_sample_dims]) + self.batch_shape 

224 value = value.expand(target_shape) 

225 value = value.contiguous() # for searchsorted later 

226 

227 return value, num_sample_dims 

228 

229 def _get_bin_indices( 

230 self, value: torch.Tensor, bin_edges: torch.Tensor | None = None, bin_centers: torch.Tensor | None = None 

231 ) -> torch.Tensor: 

232 """Get bin indices for the given values using binary search. 

233 

234 Args: 

235 value: Input tensor of shape (*sample_shape, *batch_shape). 

236 bin_edges: Tensor of bin edges, of shape (num_bins + 1,). If provided, the bin indices are determined based 

237 on the edges. 

238 bin_centers: Tensor of bin centers, of shape (num_bins,). If provided, the bin indices are determined based 

239 on the centers. 

240 

241 Returns: 

242 Tensor of bin indices for the given values, of shape (*sample_shape, *batch_shape), with values in 

243 [0, num_bins - 1] if the bins are defined by their edges or with values in [0, num_bins] if the bins are 

244 defined by their centers. 

245 """ 

246 if bin_edges is not None and bin_centers is not None: 246 ↛ 247line 246 didn't jump to line 247 because the condition on line 246 was never true

247 raise ValueError("Provide either edges or centers as input, not both.") 

248 

249 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the 

250 # index where value would be inserted to maintain sorted order. 

251 if bin_edges is not None: 

252 # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index. 

253 bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1 

254 elif bin_centers is not None: 254 ↛ 259line 254 didn't jump to line 259 because the condition on line 254 was always true

255 # If value < first center, returns 0 -> gets 0.0 from cumsum_probs 

256 # If value >= last center, returns num_bins -> gets 1.0 from cumsum_probs 

257 bin_indices = torch.searchsorted(bin_centers, value, right=True) 

258 else: 

259 raise ValueError("Either edges or centers must be provided to determine bin indices.") 

260 

261 # Clamp the output of torch.searchsorted to valid range to handle edge cases: 

262 # - values below bound_low would give bin_idx = -1 

263 # - values at bound_up would give bin_idx = num_bins 

264 if bin_edges is not None: 

265 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1) 

266 elif bin_centers is not None: 266 ↛ 269line 266 didn't jump to line 269 because the condition on line 266 was always true

267 bin_indices = torch.clamp(bin_indices, 0, self.num_bins) 

268 

269 return bin_indices 

270 

271 def _gather_from_bins( 

272 self, params: torch.Tensor, bin_indices: torch.Tensor, num_sample_dims: int, target_shape: torch.Size 

273 ) -> torch.Tensor: 

274 """Gather bin-specific parameters using aligned indices. 

275 

276 Args: 

277 params: Tensor used as the input to gather from, of shape (*batch_shape, num_bins) or 

278 (*batch_shape, num_bins + 1). 

279 bin_indices: Indices used to gather by, of shape (*sample_shape, *batch_shape). 

280 num_sample_dims: Number of leading sample dimensions in the input. 

281 target_shape: The shape to expand to, (*sample_shape, *batch_shape). 

282 

283 Returns: 

284 Gathered values of shape (*sample_shape, *batch_shape). 

285 """ 

286 # Add singleton dimensions for sample_shape: (1, ..., 1, *batch_shape, num_bins). 

287 params_view = params.view((1,) * num_sample_dims + params.shape) 

288 

289 # Expand to match the full target shape of the input. 

290 params_expanded = params_view.expand(*target_shape, -1) 

291 

292 # Gather along the last dimension. The index must be unsqueezed if indices doesn't have the bin dim yet. 

293 # Use gather with automatic broadcasting. unsqueeze(-1) provides the index dimension, 

294 # and squeeze(-1) removes it from the result. 

295 if bin_indices.ndim == len(target_shape): 

296 bin_indices = bin_indices.unsqueeze(-1) 

297 gathered = torch.gather(params_expanded, dim=-1, index=bin_indices).squeeze(-1) 

298 

299 return gathered 

300 

301 def log_prob(self, value: torch.Tensor) -> torch.Tensor: 

302 """Compute the log-probability density at given values. 

303 

304 Args: 

305 value: Values at which to compute the log PDF. 

306 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

307 

308 Returns: 

309 Log PDF values corresponding to the input values. 

310 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

311 """ 

312 if self._validate_args: 312 ↛ 315line 312 didn't jump to line 315 because the condition on line 312 was always true

313 self._validate_sample(value) 

314 

315 value_prep, num_sample_dims = self._prepare_input(value) 

316 

317 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges) 

318 

319 # Calculate the log-probabilities directly for stability. 

320 if self.bin_normalization_method == "sigmoid": 

321 # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x))) 

322 log_raw = logsigmoid(self.logits) 

323 log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True) 

324 log_bin_probs = log_raw - log_normalization 

325 else: 

326 log_bin_probs = log_softmax(self.logits, dim=-1) 

327 

328 log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape) 

329 

330 return log_probs 

331 

332 def prob(self, value: torch.Tensor) -> torch.Tensor: 

333 """Compute probability density at given values. 

334 

335 Args: 

336 value: Values at which to compute the PDF. 

337 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

338 

339 Returns: 

340 PDF values corresponding to the input values. 

341 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

342 """ 

343 if self._validate_args: 343 ↛ 346line 343 didn't jump to line 346 because the condition on line 343 was always true

344 self._validate_sample(value) 

345 

346 value_prep, num_sample_dims = self._prepare_input(value) 

347 

348 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges) 

349 

350 probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape) 

351 

352 return probs 

353 

354 def cdf(self, value: torch.Tensor) -> torch.Tensor: 

355 """Compute cumulative distribution function at given values. 

356 

357 Args: 

358 value: Values at which to compute the CDF. 

359 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

360 

361 Returns: 

362 CDF values in [0, 1] corresponding to the input values. 

363 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

364 """ 

365 if self._validate_args: 365 ↛ 368line 365 didn't jump to line 368 because the condition on line 365 was always true

366 self._validate_sample(value) 

367 

368 value_prep, num_sample_dims = self._prepare_input(value) 

369 

370 bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers) 

371 

372 # Compute the cumulative sum of bin probabilities. 

373 # Prepend 0 for the case where no bins are active. 

374 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins) 

375 zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device) 

376 cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1) # shape: (*batch_shape, num_bins + 1) 

377 

378 cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape) 

379 

380 return cdf_values 

381 

382 def icdf(self, value: torch.Tensor) -> torch.Tensor: 

383 """Compute the inverse CDF, i.e., the quantile function, at the given values. 

384 

385 Args: 

386 value: Values in [0, 1] at which to compute the inverse CDF. 

387 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

388 

389 Returns: 

390 Quantiles in [bound_low, bound_up] corresponding to the input CDF values. 

391 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

392 """ 

393 if self._validate_args: 393 ↛ 396line 393 didn't jump to line 396 because the condition on line 393 was always true

394 self._validate_sample(value) 

395 

396 value_prep, num_sample_dims = self._prepare_input(value) 

397 

398 # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge. 

399 cdf_edges = torch.cat( 

400 [ 

401 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), 

402 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins) 

403 ], 

404 dim=-1, 

405 ) # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1) 

406 

407 # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value. 

408 cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape) 

409 cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1) 

410 cdf_edges_expanded = cdf_edges_expanded.contiguous() 

411 

412 bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded) 

413 

414 quantiles = self._gather_from_bins( 

415 self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape 

416 ) 

417 

418 return quantiles # shape: (*sample_shape, *batch_shape) 

419 

420 @torch.no_grad() 

421 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor: 

422 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF. 

423 

424 Args: 

425 sample_shape: Shape of the samples to draw. 

426 

427 Returns: 

428 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution. 

429 """ 

430 shape = torch.Size(sample_shape) + self.batch_shape 

431 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) 

432 return self.icdf(uniform_samples) 

433 

434 def entropy(self) -> torch.Tensor: 

435 r"""Compute Shannon entropy of the discrete distribution. 

436 

437 $$H(X) = -\sum_{i=1}^{n} p_i \log p_i$$ 

438 where $p_i$ is the probability mass of bin $i$. 

439 """ 

440 bin_probs = self.bin_probs 

441 

442 # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log. 

443 entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8) # shape: (*batch_shape, num_bins) 

444 

445 # Sum over bins to get total entropy. 

446 return -torch.sum(entropy_per_bin, dim=-1) 

447 

448 def __repr__(self) -> str: 

449 """String representation of the distribution.""" 

450 return ( 

451 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, " 

452 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})" 

453 )