Coverage for binned_cdf/piecewise_constant_binned_cdf.py: 94%

151 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-08 12:02 +0000

1import math 

2from typing import Literal 

3 

4import torch 

5from torch.distributions import Distribution, constraints 

6from torch.nn.functional import log_softmax, logsigmoid 

7 

# Empty `torch.Size`, usable as the default `sample_shape` (a single draw per batch element).
_size = torch.Size([])

9 

10 

11class PiecewiseConstantBinnedCDF(Distribution): 

12 """A discrete probability distribution parameterized by binned logits for the CDF. 

13 

14 Each bin contributes a step function to the CDF when active. 

15 The activation of each bin is determined by applying a sigmoid to the corresponding logit. 

16 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing. 

17 

18 Note: 

19 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but 

20 not through the inputs of the `prob` or `cfg` method. 

21 """ 

22 

23 def __init__( 

24 self, 

25 logits: torch.Tensor, 

26 bound_low: float = -1e3, 

27 bound_up: float = 1e3, 

28 log_spacing: bool = False, 

29 normalization_method: Literal["sigmoid", "softmax"] = "sigmoid", 

30 validate_args: bool | None = None, 

31 ) -> None: 

32 """Initializer. 

33 

34 Args: 

35 logits: Raw logits for the bin probabilities (before sigmoid), of shape (*batch_shape, num_bins) 

36 bound_low: Lower bound of the distribution support, needs to be finite. 

37 bound_up: Upper bound of the distribution support, needs to be finite. 

38 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used. 

39 normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid", 

40 each bin is independently activated, while with "softmax", the bins activations influence each other. 

41 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class. 

42 """ 

43 self.logits = logits 

44 self.bound_low = bound_low 

45 self.bound_up = bound_up 

46 self.normalization_method = normalization_method 

47 self.log_spacing = log_spacing 

48 

49 # Create bin structure (same for all batch dimensions). 

50 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins( 

51 num_bins=logits.shape[-1], 

52 bound_low=bound_low, 

53 bound_up=bound_up, 

54 log_spacing=log_spacing, 

55 device=logits.device, 

56 dtype=logits.dtype, 

57 ) 

58 

59 # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution. 

60 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args) 

61 

62 @classmethod 

63 def _create_bins( 

64 cls, 

65 num_bins: int, 

66 bound_low: float, 

67 bound_up: float, 

68 log_spacing: bool, 

69 device: torch.device, 

70 dtype: torch.dtype, 

71 log_min_positive_edge: float = 1e-6, 

72 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

73 """Create bin edges with symmetric log spacing around zero. 

74 

75 Layout: 

76 - 1 edge at 0 

77 - num_bins//2 - 1 edges from 0 to bound_up (log spaced) 

78 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored) 

79 - 2 boundary edges at ±bounds 

80 - in total: num_bins + 1 edges creating num_bins bins 

81 

82 Args: 

83 num_bins: Number of bins to create. 

84 bound_low: Lower bound of the distribution support. 

85 bound_up: Upper bound of the distribution support. 

86 log_spacing: Whether to use logarithmic spacing. 

87 device: Device for the tensors. 

88 dtype: Data type for the tensors. 

89 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument 

90 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing. 

91 

92 Returns: 

93 Tuple of (bin_edges, bin_centers, bin_widths). 

94 """ 

95 if log_spacing: 

96 if not math.isclose(-bound_low, bound_up): 

97 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up") 

98 if bound_up <= 0: 98 ↛ 99line 98 didn't jump to line 99 because the condition on line 98 was never true

99 raise ValueError("log_spacing requires bound_up > 0") 

100 if num_bins % 2 != 0: 

101 raise ValueError("log_spacing requires even number of bins") 

102 

103 half_bins = num_bins // 2 

104 

105 # Create positive side: 0, internal edges, bound_up. 

106 if half_bins == 1: 

107 # Special case where we only use the boundary edges. 

108 positive_edges = torch.tensor([bound_up]) 

109 else: 

110 # Create half_bins - 1 internal edges between 0 and bound_up. 

111 internal_positive = torch.logspace( 

112 start=math.log2(log_min_positive_edge), 

113 end=math.log2(bound_up), 

114 steps=half_bins, 

115 base=2, 

116 ) 

117 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])]) 

118 

119 # Mirror for the negative side (excluding 0). 

120 negative_edges = -positive_edges.flip(0) 

121 

122 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary]. 

123 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges]) 

124 

125 else: 

126 # Linear spacing. 

127 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1) 

128 

129 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5 

130 bin_widths = bin_edges[1:] - bin_edges[:-1] 

131 

132 # Move to specified device and dtype. 

133 bin_edges = bin_edges.to(device=device, dtype=dtype) 

134 bin_centers = bin_centers.to(device=device, dtype=dtype) 

135 bin_widths = bin_widths.to(device=device, dtype=dtype) 

136 

137 return bin_edges, bin_centers, bin_widths 

138 

139 @property 

140 def num_bins(self) -> int: 

141 """Number of bins making up the PiecewiseConstantBinnedCDF.""" 

142 return self.logits.shape[-1] 

143 

144 @property 

145 def num_edges(self) -> int: 

146 """Number of bins edges of the PiecewiseConstantBinnedCDF.""" 

147 return self.bin_edges.shape[0] 

148 

149 @property 

150 def bin_probs(self) -> torch.Tensor: 

151 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).""" 

152 if self.normalization_method == "sigmoid": 

153 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) 

154 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True) 

155 else: 

156 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins) 

157 return bin_probs 

158 

159 @property 

160 def mean(self) -> torch.Tensor: 

161 """Compute mean of the distribution, i.e., the weighted average of bin centers. 

162 

163 Returns: 

164 Tensor of shape (*batch_shape,). 

165 """ 

166 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) 

167 return torch.sum(weighted_centers, dim=-1) 

168 

169 @property 

170 def variance(self) -> torch.Tensor: 

171 """Compute variance of the distribution. 

172 

173 Returns: 

174 Tensor of shape (*batch_shape,). 

175 """ 

176 # E[X^2] = weighted squared bin centers. 

177 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins) 

178 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,) 

179 

180 # Var = E[X^2] - E[X]^2. 

181 return second_moment - self.mean**2 

182 

183 @property 

184 def support(self) -> constraints.Constraint: 

185 """Support of this distribution. The resolution also depends on the number of bins.""" 

186 return constraints.interval(self.bound_low, self.bound_up) 

187 

188 @property 

189 def arg_constraints(self) -> dict[str, constraints.Constraint]: 

190 """Constraints that should be satisfied by each argument of this distribution. None for this class.""" 

191 return {"logits": constraints.real} 

192 

193 def expand( 

194 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None 

195 ) -> "PiecewiseConstantBinnedCDF": 

196 """Expand distribution to new batch shape. This creates a new instance.""" 

197 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins)) 

198 return self.__class__( 

199 logits=expanded_logits, 

200 bound_low=self.bound_low, 

201 bound_up=self.bound_up, 

202 log_spacing=self.log_spacing, 

203 validate_args=self._validate_args, 

204 ) 

205 

206 def _prepare_input(self, value: torch.Tensor) -> tuple[torch.Tensor, int]: 

207 """Prepare the input tensor for `log_prob`, `prob`, `cdf` and `icdf` computations. 

208 

209 This method handles device/dtype transfer, batch dimension alignment, and broadcasting. 

210 

211 Args: 

212 value: Input tensor to prepare. Expected shape: `(*sample_shape, *batch_shape)` or broadcastable to it. 

213 For example, if `batch_shape` is `(B1, B2)` and `value` is `(S1, S2)`, it will be broadcast to 

214 `(S1, S2, B1, B2)`. If `value` is `(B1, B2)` (no sample dims), it remains `(B1, B2)`. 

215 

216 Returns: 

217 A tuple containing: 

218 - Prepared `value` tensor, of shape: `(*sample_shape, *batch_shape)`. 

219 - `num_sample_dims`: The number of sample dimensions in the prepared `value` tensor. 

220 """ 

221 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

222 

223 # This ensures the batch dimension is the last dimension. 

224 if len(self.batch_shape) > 0: # noqa: SIM102 

225 # Check if the rightmost dimensions of value match batch_shape. 

226 # If they don't, we assume value is missing the batch dimensions. 

227 if value.shape[-len(self.batch_shape) :] != self.batch_shape: 

228 value = value.unsqueeze(-1) 

229 

230 num_sample_dims = max(0, value.ndim - len(self.batch_shape)) 

231 target_shape = self._extended_shape(sample_shape=value.shape[:num_sample_dims]) 

232 value = value.expand(target_shape) 

233 value = value.contiguous() # for searchsorted later 

234 

235 return value, num_sample_dims 

236 

237 def _get_bin_indices( 

238 self, value: torch.Tensor, bin_edges: torch.Tensor | None = None, bin_centers: torch.Tensor | None = None 

239 ) -> torch.Tensor: 

240 """Get bin indices for the given values using binary search. 

241 

242 Args: 

243 value: Input tensor of shape (*sample_shape, *batch_shape). 

244 bin_edges: Tensor of bin edges, of shape (num_bins + 1,). If provided, the bin indices are determined based 

245 on the edges. 

246 bin_centers: Tensor of bin centers, of shape (num_bins,). If provided, the bin indices are determined based 

247 on the centers. 

248 

249 Returns: 

250 Tensor of bin indices for the given values, of shape (*sample_shape, *batch_shape), with values in 

251 [0, num_bins - 1] if the bins are defined by their edges or with values in [0, num_bins] if the bins are 

252 defined by their centers. 

253 """ 

254 if bin_edges is not None and bin_centers is not None: 254 ↛ 255line 254 didn't jump to line 255 because the condition on line 254 was never true

255 raise ValueError("Provide either edges or centers as input, not both.") 

256 

257 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the 

258 # index where value would be inserted to maintain sorted order. 

259 if bin_edges is not None: 

260 # Since bins are defined as [edge[i], edge[i + 1]), we subtract 1 to get the bin index. 

261 bin_indices = torch.searchsorted(bin_edges, value, right=True) - 1 

262 elif bin_centers is not None: 262 ↛ 267line 262 didn't jump to line 267 because the condition on line 262 was always true

263 # If value < first center, returns 0 -> gets 0.0 from cumsum_probs 

264 # If value >= last center, returns num_bins -> gets 1.0 from cumsum_probs 

265 bin_indices = torch.searchsorted(bin_centers, value, right=True) 

266 else: 

267 raise ValueError("Either edges or centers must be provided to determine bin indices.") 

268 

269 # Clamp the output of torch.searchsorted to valid range to handle edge cases: 

270 # - values below bound_low would give bin_idx = -1 

271 # - values at bound_up would give bin_idx = num_bins 

272 if bin_edges is not None: 

273 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1) 

274 elif bin_centers is not None: 274 ↛ 277line 274 didn't jump to line 277 because the condition on line 274 was always true

275 bin_indices = torch.clamp(bin_indices, 0, self.num_bins) 

276 

277 return bin_indices 

278 

279 def _gather_from_bins( 

280 self, params: torch.Tensor, bin_indices: torch.Tensor, num_sample_dims: int, target_shape: torch.Size 

281 ) -> torch.Tensor: 

282 """Gather bin-specific parameters using aligned indices. 

283 

284 Args: 

285 params: Tensor used as the input to gather from, of shape (*batch_shape, num_bins) or 

286 (*batch_shape, num_bins + 1). 

287 bin_indices: Indices used to gather by, of shape (*sample_shape, *batch_shape). 

288 num_sample_dims: Number of leading sample dimensions in the input. 

289 target_shape: The shape to expand to, (*sample_shape, *batch_shape). 

290 

291 Returns: 

292 Gathered values of shape (*sample_shape, *batch_shape). 

293 """ 

294 # Add singleton dimensions for sample_shape: (1, ..., 1, *batch_shape, num_bins). 

295 params_view = params.view((1,) * num_sample_dims + params.shape) 

296 

297 # Expand to match the full target shape of the input. 

298 params_expanded = params_view.expand(*target_shape, -1) 

299 

300 # Gather along the last dimension. The index must be unsqueezed if indices doesn't have the bin dim yet. 

301 # Use gather with automatic broadcasting. unsqueeze(-1) provides the index dimension, 

302 # and squeeze(-1) removes it from the result. 

303 if bin_indices.ndim == len(target_shape): 

304 bin_indices = bin_indices.unsqueeze(-1) 

305 gathered = torch.gather(params_expanded, dim=-1, index=bin_indices).squeeze(-1) 

306 

307 return gathered 

308 

309 def log_prob(self, value: torch.Tensor) -> torch.Tensor: 

310 """Compute the log-probability density at given values. 

311 

312 Args: 

313 value: Values at which to compute the log-PDF. 

314 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

315 

316 Returns: 

317 Log-PDF values corresponding to the input values. 

318 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

319 """ 

320 if self._validate_args: 320 ↛ 323line 320 didn't jump to line 323 because the condition on line 320 was always true

321 self._validate_sample(value) 

322 

323 value_prep, num_sample_dims = self._prepare_input(value) 

324 

325 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges) 

326 

327 # Calculate the log-probabilities directly for stability. 

328 if self.normalization_method == "sigmoid": 

329 # Normalized logsigmoid: log(sigmoid(x) / sum(sigmoid(x))) 

330 log_raw = logsigmoid(self.logits) 

331 log_normalization = torch.logsumexp(log_raw, dim=-1, keepdim=True) 

332 log_bin_probs = log_raw - log_normalization 

333 else: 

334 log_bin_probs = log_softmax(self.logits, dim=-1) 

335 

336 log_probs = self._gather_from_bins(log_bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape) 

337 

338 return log_probs 

339 

340 def prob(self, value: torch.Tensor) -> torch.Tensor: 

341 """Compute probability density at given values. 

342 

343 Args: 

344 value: Values at which to compute the PDF. 

345 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

346 

347 Returns: 

348 PDF values corresponding to the input values. 

349 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

350 """ 

351 if self._validate_args: 351 ↛ 354line 351 didn't jump to line 354 because the condition on line 351 was always true

352 self._validate_sample(value) 

353 

354 value_prep, num_sample_dims = self._prepare_input(value) 

355 

356 bin_indices = self._get_bin_indices(value_prep, bin_edges=self.bin_edges) 

357 

358 probs = self._gather_from_bins(self.bin_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape) 

359 

360 return probs 

361 

362 def cdf(self, value: torch.Tensor) -> torch.Tensor: 

363 """Compute cumulative distribution function at given values. 

364 

365 Args: 

366 value: Values at which to compute the CDF. 

367 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

368 

369 Returns: 

370 CDF values in [0, 1] corresponding to the input values. 

371 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

372 """ 

373 if self._validate_args: 373 ↛ 376line 373 didn't jump to line 376 because the condition on line 373 was always true

374 self._validate_sample(value) 

375 

376 value_prep, num_sample_dims = self._prepare_input(value) 

377 

378 bin_indices = self._get_bin_indices(value_prep, bin_centers=self.bin_centers) 

379 

380 # Compute the cumulative sum of bin probabilities. 

381 # Prepend 0 for the case where no bins are active. 

382 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins) 

383 zero_prefix = torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device) 

384 cumsum_probs = torch.cat([zero_prefix, cumsum_probs], dim=-1) # shape: (*batch_shape, num_bins + 1) 

385 

386 cdf_values = self._gather_from_bins(cumsum_probs, bin_indices, num_sample_dims, target_shape=value_prep.shape) 

387 

388 return cdf_values 

389 

390 def icdf(self, value: torch.Tensor) -> torch.Tensor: 

391 """Compute the inverse CDF, i.e., the quantile function, at the given values. 

392 

393 Args: 

394 value: Values in [0, 1] at which to compute the inverse CDF. 

395 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

396 

397 Returns: 

398 Quantiles in [bound_low, bound_up] corresponding to the input CDF values. 

399 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

400 """ 

401 if self._validate_args: 401 ↛ 404line 401 didn't jump to line 404 because the condition on line 401 was always true

402 self._validate_sample(value) 

403 

404 value_prep, num_sample_dims = self._prepare_input(value) 

405 

406 # Compute CDF at bin edges. Prepend zeros to the cumsum of probabilities as this is always the first edge. 

407 cdf_edges = torch.cat( 

408 [ 

409 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), 

410 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins) 

411 ], 

412 dim=-1, 

413 ) # [0, p1, p1+p2, ..., 1.0], shape: (*batch_shape, num_bins + 1) 

414 

415 # Prepend singleton dimensions for sample_shape to cdf_edges and expand to match value. 

416 cdf_edges_expanded = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape) 

417 cdf_edges_expanded = cdf_edges_expanded.expand(*value_prep.shape, -1) 

418 cdf_edges_expanded = cdf_edges_expanded.contiguous() 

419 

420 bin_indices = self._get_bin_indices(value_prep.unsqueeze(-1), bin_edges=cdf_edges_expanded) 

421 

422 quantiles = self._gather_from_bins( 

423 self.bin_centers, bin_indices, num_sample_dims, target_shape=value_prep.shape 

424 ) 

425 

426 return quantiles # shape: (*sample_shape, *batch_shape) 

427 

428 @torch.no_grad() 

429 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor: 

430 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF. 

431 

432 Args: 

433 sample_shape: Shape of the samples to draw. 

434 

435 Returns: 

436 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution. 

437 """ 

438 # Determine the final shape of the output tensor. 

439 shape = self._extended_shape(sample_shape) 

440 

441 # Sample in [0, 1] and transform through inverse CDF to get samples in [bound_low, bound_up]. 

442 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) 

443 samples = self.icdf(uniform_samples) 

444 

445 return samples 

446 

447 def entropy(self) -> torch.Tensor: 

448 r"""Compute Shannon entropy of the discrete distribution. 

449 

450 $$H[X] = -\sum_{i=1}^{n} p_i \log p_i$$ 

451 where $p_i$ is the probability mass of bin $i$. 

452 

453 Returns: 

454 Tensor of shape (*batch_shape,). 

455 """ 

456 bin_probs = self.bin_probs 

457 

458 # Compute entropy per bin and sum over bins. Add small epsilon for numerical stability in log. 

459 entropy_per_bin = bin_probs * torch.log(bin_probs + 1e-8) # shape: (*batch_shape, num_bins) 

460 

461 # Sum over bins to get total entropy. 

462 return -torch.sum(entropy_per_bin, dim=-1) 

463 

464 def __repr__(self) -> str: 

465 """String representation of the distribution.""" 

466 return ( 

467 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, " 

468 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing}, " 

469 f"normalization_method: {self.normalization_method})" 

470 )