Coverage for binned_cdf / binned_logit_cdf.py: 96%

138 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-05 16:38 +0000

1import math 

2from typing import Literal 

3 

4import torch 

5from torch.distributions import Distribution, constraints 

6 

7_size = torch.Size() 

8 

9 

10class BinnedLogitCDF(Distribution): 

11 """A histogram-based probability distribution parameterized by a bins for the CDF. 

12 

13 Each bin contributes a step function to the CDF when active. 

14 The activation of each bin is determined by applying a sigmoid to the corresponding logit. 

15 The distribution is defined over the interval [bound_low, bound_up] with either linear or logarithmic bin spacing. 

16 

17 Note: 

18 This distribution is differentiable with respect to the logits, i.e., the arguments of `__init__`, but 

19 not through the inputs of the `prob` or `cfg` method. 

20 """ 

21 

22 def __init__( 

23 self, 

24 logits: torch.Tensor, 

25 bound_low: float = -1e3, 

26 bound_up: float = 1e3, 

27 log_spacing: bool = False, 

28 bin_normalization_method: Literal["sigmoid", "softmax"] = "sigmoid", 

29 validate_args: bool | None = None, 

30 ) -> None: 

31 """Initializer. 

32 

33 Args: 

34 logits: Raw logits for bin probabilities (before sigmoid), of shape (*batch_shape, num_bins) 

35 bound_low: Lower bound of the distribution support, needs to be finite. 

36 bound_up: Upper bound of the distribution support, needs to be finite. 

37 log_spacing: Whether logarithmic (base = 2) spacing for the bins or linear spacing should be used. 

38 bin_normalization_method: How to normalize bin probabilities. Either "sigmoid" or "softmax". With "sigmoid", 

39 each bin is independently activated, while with "softmax", the bins activations influence each other. 

40 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class. 

41 """ 

42 self.logits = logits 

43 self.bound_low = bound_low 

44 self.bound_up = bound_up 

45 self.bin_normalization_method = bin_normalization_method 

46 self.log_spacing = log_spacing 

47 

48 # Create bin structure (same for all batch dimensions). 

49 self.bin_edges, self.bin_centers, self.bin_widths = self._create_bins( 

50 num_bins=logits.shape[-1], 

51 bound_low=bound_low, 

52 bound_up=bound_up, 

53 log_spacing=log_spacing, 

54 device=logits.device, 

55 dtype=logits.dtype, 

56 ) 

57 

58 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args) 

59 

60 @classmethod 

61 def _create_bins( 

62 cls, 

63 num_bins: int, 

64 bound_low: float, 

65 bound_up: float, 

66 log_spacing: bool, 

67 device: torch.device, 

68 dtype: torch.dtype, 

69 log_min_positive_edge: float = 1e-6, 

70 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

71 """Create bin edges with symmetric log spacing around zero. 

72 

73 Args: 

74 num_bins: Number of bins to create. 

75 bound_low: Lower bound of the distribution support. 

76 bound_up: Upper bound of the distribution support. 

77 log_spacing: Whether to use logarithmic spacing. 

78 device: Device for the tensors. 

79 dtype: Data type for the tensors. 

80 log_min_positive_edge: Minimum positive edge when using log spacing. The log2-value of this argument 

81 will be passed to torch.logspace. Too small values, approx below 1e-9, will result in poor bin spacing. 

82 

83 Returns: 

84 Tuple of (bin_edges, bin_centers, bin_widths). 

85 

86 Layout: 

87 - 1 edge at 0 

88 - num_bins//2 - 1 edges from 0 to bound_up (log spaced) 

89 - num_bins//2 - 1 edges from 0 to -bound_low (log spaced, mirrored) 

90 - 2 boundary edges at ±bounds 

91 

92 Total: num_bins + 1 edges creating num_bins bins 

93 """ 

94 if log_spacing: 

95 if not math.isclose(-bound_low, bound_up): 

96 raise ValueError("log_spacing requires symmetric bounds: -bound_low == bound_up") 

97 if bound_up <= 0: 97 ↛ 98line 97 didn't jump to line 98 because the condition on line 97 was never true

98 raise ValueError("log_spacing requires bound_up > 0") 

99 if num_bins % 2 != 0: 

100 raise ValueError("log_spacing requires even number of bins") 

101 

102 half_bins = num_bins // 2 

103 

104 # Create positive side: 0, internal edges, bound_up. 

105 if half_bins == 1: 

106 # Special case where we only use the boundary edges. 

107 positive_edges = torch.tensor([bound_up]) 

108 else: 

109 # Create half_bins - 1 internal edges between 0 and bound_up. 

110 internal_positive = torch.logspace( 

111 start=math.log2(log_min_positive_edge), 

112 end=math.log2(bound_up), 

113 steps=half_bins, 

114 base=2, 

115 ) 

116 positive_edges = torch.cat([internal_positive[:-1], torch.tensor([bound_up])]) 

117 

118 # Mirror for the negative side (excluding 0). 

119 negative_edges = -positive_edges.flip(0) 

120 

121 # Combine to [negative_boundary, negative_internal, 0, positive_internal, positive_boundary]. 

122 bin_edges = torch.cat([negative_edges, torch.tensor([0.0]), positive_edges]) 

123 

124 else: 

125 # Linear spacing. 

126 bin_edges = torch.linspace(start=bound_low, end=bound_up, steps=num_bins + 1) 

127 

128 bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5 

129 bin_widths = bin_edges[1:] - bin_edges[:-1] 

130 

131 # Move to specified device and dtype. 

132 bin_edges = bin_edges.to(device=device, dtype=dtype) 

133 bin_centers = bin_centers.to(device=device, dtype=dtype) 

134 bin_widths = bin_widths.to(device=device, dtype=dtype) 

135 

136 return bin_edges, bin_centers, bin_widths 

137 

138 @property 

139 def num_bins(self) -> int: 

140 """Number of bins making up the BinnedLogitCDF.""" 

141 return self.logits.shape[-1] 

142 

143 @property 

144 def num_edges(self) -> int: 

145 """Number of bins edges of the BinnedLogitCDF.""" 

146 return self.bin_edges.shape[0] 

147 

148 @property 

149 def bin_probs(self) -> torch.Tensor: 

150 """Get normalized probabilities for each bin, of shape (*batch_shape, num_bins).""" 

151 if self.bin_normalization_method == "sigmoid": 

152 raw_probs = torch.sigmoid(self.logits) # shape: (*batch_shape, num_bins) 

153 bin_probs = raw_probs / raw_probs.sum(dim=-1, keepdim=True) 

154 else: 

155 bin_probs = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, num_bins) 

156 return bin_probs 

157 

158 @property 

159 def mean(self) -> torch.Tensor: 

160 """Compute mean of the distribution, i.e., the weighted average of bin centers, of shape (*batch_shape,).""" 

161 weighted_centers = self.bin_probs * self.bin_centers # shape: (*batch_shape, num_bins) 

162 return torch.sum(weighted_centers, dim=-1) 

163 

164 @property 

165 def variance(self) -> torch.Tensor: 

166 """Compute variance of the distribution, of shape (*batch_shape,).""" 

167 # E[X^2] = weighted squared bin centers. 

168 weighted_centers_sq = self.bin_probs * (self.bin_centers**2) # shape: (*batch_shape, num_bins) 

169 second_moment = torch.sum(weighted_centers_sq, dim=-1) # shape: (*batch_shape,) 

170 

171 # Var = E[X^2] - E[X]^2 

172 return second_moment - self.mean**2 

173 

174 @property 

175 def support(self) -> constraints.Constraint: 

176 """Support of this distribution. Needs to be limitited to keep the number of bins manageable.""" 

177 return constraints.interval(self.bound_low, self.bound_up) 

178 

179 @property 

180 def arg_constraints(self) -> dict[str, constraints.Constraint]: 

181 """Constraints that should be satisfied by each argument of this distribution. None for this class.""" 

182 return {} 

183 

184 def expand( 

185 self, batch_shape: torch.Size | list[int] | tuple[int, ...], _instance: Distribution | None = None 

186 ) -> "BinnedLogitCDF": 

187 """Expand distribution to new batch shape. This creates a new instance.""" 

188 expanded_logits = self.logits.expand((*torch.Size(batch_shape), self.num_bins)) 

189 return BinnedLogitCDF( 

190 logits=expanded_logits, 

191 bound_low=self.bound_low, 

192 bound_up=self.bound_up, 

193 log_spacing=self.log_spacing, 

194 validate_args=self._validate_args, 

195 ) 

196 

197 def log_prob(self, value: torch.Tensor) -> torch.Tensor: 

198 """Compute log probability density at given values. 

199 

200 Args: 

201 value: Values at which to compute the log PDF. 

202 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

203 

204 Returns: 

205 Log PDF values corresponding to the input values. 

206 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

207 """ 

208 return torch.log(self.prob(value) + 1e-8) # small epsilon for stability 

209 

210 def prob(self, value: torch.Tensor) -> torch.Tensor: 

211 """Compute probability density at given values. 

212 

213 Args: 

214 value: Values at which to compute the PDF. 

215 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

216 

217 Returns: 

218 PDF values corresponding to the input values. 

219 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

220 """ 

221 if self._validate_args: 221 ↛ 224line 221 didn't jump to line 224 because the condition on line 221 was always true

222 self._validate_sample(value) 

223 

224 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

225 

226 # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions). 

227 if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape): 

228 value = value.expand(self.batch_shape) 

229 

230 # Use binary search to find which bin each value belongs to. The torch.searchsorted function returns the 

231 # index where value would be inserted to maintain sorted order. 

232 # Since bins are defined as [edge[i], edge[i+1]), we subtract 1 to get the bin index. 

233 bin_indices = torch.searchsorted(self.bin_edges, value) - 1 # shape: (*sample_shape, *batch_shape) 

234 

235 # Clamp to valid range [0, num_bins - 1] to handle edge cases: 

236 # - values below bound_low would give bin_idx = -1 

237 # - values at bound_up would give bin_idx = num_bins 

238 bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1) 

239 

240 # Gather the bin widths and probabilities for the selected bins. 

241 # For bin_widths of shape (num_bins,) we can index directly. 

242 bin_widths_selected = self.bin_widths[bin_indices] # shape: (*sample_shape, *batch_shape) 

243 

244 # For bin_probs of shape (*batch_shape, num_bins) we need to use gather along the last dimension. 

245 # Add sample dimensions to bin_probs and expand to match bin_indices shape. 

246 num_sample_dims = len(bin_indices.shape) - len(self.batch_shape) 

247 bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape) 

248 bin_probs_for_gather = bin_probs_for_gather.expand( 

249 *bin_indices.shape, -1 

250 ) # shape: (*sample_shape, *batch_shape, num_bins) 

251 

252 # Gather the selected bin probabilities. 

253 bin_indices_for_gather = bin_indices.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1) 

254 bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather) 

255 bin_probs_selected = bin_probs_selected.squeeze(-1) 

256 

257 # Compute PDF = probability mass / bin width. 

258 return bin_probs_selected / bin_widths_selected 

259 

260 def cdf(self, value: torch.Tensor) -> torch.Tensor: 

261 """Compute cumulative distribution function at given values. 

262 

263 Args: 

264 value: Values at which to compute the CDF. 

265 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

266 

267 Returns: 

268 CDF values in [0, 1] corresponding to the input values. 

269 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

270 """ 

271 if self._validate_args: 271 ↛ 274line 271 didn't jump to line 274 because the condition on line 271 was always true

272 self._validate_sample(value) 

273 

274 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

275 

276 # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions). 

277 if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape): 

278 value = value.expand(self.batch_shape) 

279 

280 # Use binary search to find how many bin centers are <= value. 

281 # torch.searchsorted with right=True gives us the number of elements <= value. 

282 num_bins_active = torch.searchsorted(self.bin_centers, value, right=True) 

283 

284 # Clamp to valid range [0, num_bins]. 

285 num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins) # shape: (*sample_shape, *batch_shape) 

286 

287 # Compute cumulative sum of bin probabilities. 

288 # Prepend 0 for the case where no bins are active. 

289 num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape) 

290 cumsum_probs = torch.cumsum(self.bin_probs, dim=-1) # shape: (*batch_shape, num_bins) 

291 cumsum_probs = torch.cat( 

292 [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs], 

293 dim=-1, 

294 ) # shape: (*batch_shape, num_bins + 1) 

295 

296 # Expand cumsum_probs to match sample dimensions and gather. 

297 cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape) 

298 cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1) 

299 num_bins_active_for_gather = num_bins_active.unsqueeze(-1) # shape: (*sample_shape, *batch_shape, 1) 

300 cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather) 

301 cdf_values = cdf_values.squeeze(-1) 

302 

303 return cdf_values 

304 

305 def icdf(self, value: torch.Tensor) -> torch.Tensor: 

306 """Compute the inverse CDF, i.e., the quantile function, at the given values. 

307 

308 Args: 

309 value: Values in [0, 1] at which to compute the inverse CDF. 

310 Expected shape: (*sample_shape, *batch_shape) or broadcastable to it. 

311 

312 Returns: 

313 Quantiles in [bound_low, bound_up] corresponding to the input CDF values. 

314 Output shape: same as `value` shape after broadcasting, i.e., (*sample_shape, *batch_shape). 

315 """ 

316 if self._validate_args and not (value >= 0).all() and (value <= 1).all(): 316 ↛ 317line 316 didn't jump to line 317 because the condition on line 316 was never true

317 raise ValueError("icdf input must be in [0, 1]") 

318 

319 value = value.to(dtype=self.logits.dtype, device=self.logits.device) 

320 

321 # Compute CDF at bin edges. prepend zeros to the cumsum of probabilities as this is always the first edge. 

322 cdf_edges = torch.cat( 

323 [ 

324 torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), 

325 torch.cumsum(self.bin_probs, dim=-1), # shape: (*batch_shape, num_bins) 

326 ], 

327 dim=-1, 

328 ) # shape: (*batch_shape, num_bins + 1) 

329 

330 # Determine number of sample dimensions (dimensions before batch_shape). 

331 num_sample_dims = len(value.shape) - len(self.batch_shape) 

332 

333 # Prepend singleton dimensions for sample_shape to cdf_edges. 

334 # cdf_edges: (*batch_shape, num_bins + 1) -> (*sample_shape, *batch_shape, num_bins + 1) 

335 cdf_edges = cdf_edges.view((1,) * num_sample_dims + cdf_edges.shape) 

336 

337 # Prepend singleton dimensions for both sample_shape and batch_shape. 

338 # bin_edges: (num_bins + 1,) -> (*sample_shape, *batch_shape, num_bins + 1) 

339 bin_edges_expanded = self.bin_edges.view( 

340 (1,) * (num_sample_dims + len(self.batch_shape)) + self.bin_edges.shape 

341 ) 

342 

343 # Add bin dimension to value for comparison. 

344 value_expanded = value.unsqueeze(-1) 

345 

346 # Find bins containing the value: left_cdf <= value < right_cdf. 

347 bin_mask = (cdf_edges[..., :-1] <= value_expanded) & (value_expanded < cdf_edges[..., 1:]) 

348 bin_mask = bin_mask.to(self.logits.dtype) 

349 

350 # Handle edge case where value ≈ 1.0 (use isclose with dtype-appropriate defaults). 

351 value_is_one = torch.isclose(value_expanded, torch.ones_like(value_expanded)) 

352 bin_mask[..., -1] = torch.max(bin_mask[..., -1], value_is_one[..., 0]) # last bin could be selected already 

353 

354 # Selected the correct bin edges using the mask. Summing is essentially selecting here. 

355 # Summing fast and differentiable. 

356 cfd_value_bin_starts = torch.sum(bin_mask * cdf_edges[..., :-1], dim=-1) 

357 cdf_value_bin_ends = torch.sum(bin_mask * cdf_edges[..., 1:], dim=-1) 

358 bin_left_edges = torch.sum(bin_mask * bin_edges_expanded[..., :-1], dim=-1) 

359 bin_right_edges = torch.sum(bin_mask * bin_edges_expanded[..., 1:], dim=-1) 

360 

361 # Avoid division by zero. 

362 bin_width = cdf_value_bin_ends - cfd_value_bin_starts 

363 safe_bin_width = torch.where(bin_width > 1e-8, bin_width, torch.ones_like(bin_width)) 

364 

365 # Linear interpolation within the bin. 

366 alpha = (value - cfd_value_bin_starts) / safe_bin_width 

367 quantiles = bin_left_edges + alpha * (bin_right_edges - bin_left_edges) 

368 

369 return quantiles 

370 

371 @torch.no_grad() 

372 def sample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor: 

373 """Sample from the distribution by passing uniformly random draws from [0, 1] thought the inverse CDF. 

374 

375 Args: 

376 sample_shape: Shape of the samples to draw. 

377 

378 Returns: 

379 Samples of shape (sample_shape + batch_shape), where batch_shape is the batch shape of the distribution. 

380 """ 

381 shape = torch.Size(sample_shape) + self.batch_shape 

382 uniform_samples = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) 

383 return self.icdf(uniform_samples) 

384 

385 def entropy(self) -> torch.Tensor: 

386 r"""Compute differential entropy of the distribution. 

387 

388 Entropy H(X) = -\sum_{x \in \mathcal{X}} p(x) \log( p(x) ) 

389 

390 Note: 

391 Here, we are doing an approximation by treating each bin as a uniform distribution over its width. 

392 """ 

393 bin_probs = self.bin_probs 

394 

395 # Get the PDF values at bin centers. 

396 pdf_values = bin_probs / self.bin_widths # shape: (*batch_shape, num_bins) 

397 

398 # Entropy ≈ -∑ p_i * log(pdf_i) * bin_width_i. 

399 log_pdf = torch.log(pdf_values + 1e-8) # small epsilon for stability 

400 entropy_per_bin = -bin_probs * log_pdf 

401 

402 # Sum over bins to get total entropy. 

403 return torch.sum(entropy_per_bin, dim=-1) 

404 

405 def __repr__(self) -> str: 

406 """String representation of the distribution.""" 

407 return ( 

408 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, " 

409 f"bound_up: {self.bound_up}, log_spacing: {self.log_spacing})" 

410 )