Coverage for binned_cdf/bezier_cdf.py: 99%

142 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-08 12:02 +0000

1import math 

2from typing import Literal 

3 

4import torch 

5from torch.distributions import Distribution, constraints 

6 

7_size = torch.Size() 

8 

9 

10class BezierCDF(Distribution): 

11 r"""A continuous probability distribution parameterized by Bernstein polynomials with custom constraints. 

12 

13 The idea is that the CDF is represented as a Bezier curve, which is a weighted sum of Bernstein basis polynomials, 

14 defined by control points (betas) that are derived from the input logits. 

15 This allows for a smooth, flexible CDF that can capture complex shapes while still being differentiable. 

16 In fact, this formulation is mathematically equivalent to a mixture of Beta distributions, where the mixture 

17 weights are given by the deltas (softmax of the logits) and the Beta components are defined by the control points. 

18 

19 Since we know that any CDF must start at 0 and end at 1, we can enforce these constraints by fixing the first 

20 control point to 0 and the last control point to 1. 

21 

22 The spacing of the control points along the domain-axis ("x-axis") is strictly uniform and determined by the 

23 degree of the Bernstein polynomial, hence, number of input logits. 

24 

25 Note: 

26 Bernstein polynomials converge slowly: the worst-case pointwise approximation error is $O(1/n)$ where $n$ is 

27 the polynomial degree, leading to a standard deviation error of $O(1/\sqrt{n})$. However, for smooth CDFs the 

28 effective rate is better, and Bernstein density estimators achieve the optimal minimax rate (Babu et al., 2002; 

29 Petrone, 1999). This slower convergence is an inherent trade-off for the structural guarantees they provide: 

30 monotonicity, values in $[0, 1]$, non-negative PDF, and an unconstrained parameterization (any real-valued 

31 logits yield a valid distribution). No other polynomial basis offers all of these simultaneously. In practice, 

32 the bias matters less when logits are learned end-to-end via gradient descent, as the optimizer can compensate. 

33 

34 The sharpest peak a degree-n Bernstein polynomial can produce is a single Beta component with 

35 $std \approx 1/(2\sqrt{n})$ in [0,1]-space. Scaled to support range R, the peak std is $R / (2\sqrt{n})$. 

36 """ 

37 

38 has_rsample = True 

39 

40 def __init__( 

41 self, 

42 logits: torch.Tensor, 

43 bound_low: float = -1e3, 

44 bound_up: float = 1e3, 

45 normalization_method: Literal["sigmoid", "softmax"] = "softmax", 

46 validate_args: bool | None = None, 

47 ) -> None: 

48 """Initializer. 

49 

50 Args: 

51 logits: Raw logits for the probabilities before normalization, of shape (*batch_shape, degree). 

52 The logits also determine the degree of the Bernstein polynomial $n$. 

53 bound_low: Lower bound of the distribution support, needs to be finite. 

54 bound_up: Upper bound of the distribution support, needs to be finite. 

55 normalization_method: How to normalize the probabilities. Either "sigmoid" or "softmax". With "sigmoid", 

56 each control point is independently activated, while with "softmax", the control point activations 

57 influence each other. 

58 validate_args: Whether to validate arguments. Carried over to keep the interface with the base class. 

59 """ 

60 self.logits = logits 

61 self.bound_low = bound_low 

62 self.bound_up = bound_up 

63 self.normalization_method = normalization_method 

64 

65 # Precompute binomial coefficients, and store them on the same device as logits. 

66 self._binom_coeffs_cdf, self._binom_coeffs_pdf = self._compute_binomial_coefficients() 

67 

68 # Precompute log-space binomial coefficients for numerically stable log_prob. 

69 self._log_binom_coeffs_pdf = self._binom_coeffs_pdf.log() 

70 

71 # Calculate parameters (deltas and betas). 

72 self._deltas, self._betas, self._log_deltas = self._compute_deltas_and_betas() 

73 

74 # Determine batch shape based on the logits. The event shape is scalar since this is a univariate distribution. 

75 super().__init__(batch_shape=logits.shape[:-1], event_shape=torch.Size([]), validate_args=validate_args) 

76 

77 def __repr__(self) -> str: 

78 """String representation of the distribution.""" 

79 return ( 

80 f"{self.__class__.__name__}(logits_shape: {self.logits.shape}, bound_low: {self.bound_low}, " 

81 f"bound_up: {self.bound_up}, normalization_method: {self.normalization_method})" 

82 ) 

83 

84 def _compute_binomial_coefficients(self) -> tuple[torch.Tensor, torch.Tensor]: 

85 """Compute the binomial coefficients for the CDF and PDF based on the degree of the Bernstein polynomial. 

86 

87 comb(n, k) = n! / (k! * (n-k)!) is the binomial coefficient, which counts the number of ways to choose k 

88 elements from a set of n elements. 

89 

90 Returns: 

91 coeffs_cdf: Binomial coefficients for the CDF, of shape (degree + 1,) 

92 coeffs_pdf: Binomial coefficients for the PDF, of shape (degree,) 

93 """ 

94 coeffs_cdf = torch.tensor( 

95 [math.comb(self.degree, i) for i in range(self.degree + 1)], 

96 device=self.logits.device, 

97 dtype=self.logits.dtype, 

98 ) 

99 

100 coeffs_pdf = torch.tensor( 

101 [math.comb(self.degree - 1, i) for i in range(self.degree)], 

102 device=self.logits.device, 

103 dtype=self.logits.dtype, 

104 ) 

105 

106 # Check if any of the binomial coefficients became infinite. 

107 if torch.isinf(coeffs_cdf).any() or torch.isinf(coeffs_pdf).any(): 

108 raise ValueError( 

109 f"Binomial coefficients became infinite for degree {self.degree}. " 

110 "Consider reducing the (last) dimension of the logits, leading to lower degree polynomial." 

111 ) 

112 

113 return coeffs_cdf, coeffs_pdf 

114 

115 def _compute_deltas_and_betas(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

116 r"""Compute the deltas (Beta mixture component weights) and betas (control points) for the Bezier curve based 

117 on the given logits. 

118 

119 The deltas are the forward differences of the betas, i.e., $ \Delta_i = \beta_{i + 1} - \beta_i $. 

120 

121 Returns: 

122 deltas: Weights of the Beta components in the mixture, of shape (*batch_shape, degree) 

123 betas: Control points of the Bezier curve, of shape (*batch_shape, degree + 1) 

124 log_deltas: Log of the deltas, computed in a numerically stable way, of shape (*batch_shape, degree) 

125 """ 

126 # The deltas are the steps themselves (forward differences of betas). 

127 if self.normalization_method == "softmax": 

128 deltas = torch.softmax(self.logits, dim=-1) # shape: (*batch_shape, degree) 

129 log_deltas = torch.log_softmax(self.logits, dim=-1) 

130 

131 elif self.normalization_method == "sigmoid": 131 ↛ 145line 131 didn't jump to line 145 because the condition on line 131 was always true

132 raw_deltas = torch.sigmoid(self.logits) 

133 sum_deltas = raw_deltas.sum(dim=-1, keepdim=True) 

134 

135 # Prevent division by zero in the rare case where all logits are massively negative. 

136 eps = torch.finfo(raw_deltas.dtype).eps 

137 sum_deltas = sum_deltas.clamp_min(eps) 

138 

139 deltas = raw_deltas / sum_deltas # shape: (*batch_shape, degree) 

140 

141 # log(Delta) = log(sigmoid(x) / sum(sigmoid(x))) = logsigmoid(x) - log(sum(sigmoid(x))). 

142 log_deltas = torch.nn.functional.logsigmoid(self.logits) - sum_deltas.log() 

143 

144 else: 

145 raise ValueError(f"Unknown normalization method: {self.normalization_method}") 

146 

147 # Pad with zeros and ones to enforce the CDF boundary conditions: 

148 # betas = [0, beta_1, ..., beta_{n-1}, beta_n = 1] 

149 zeros = torch.zeros(*deltas.shape[:-1], 1, device=deltas.device, dtype=deltas.dtype) # shape: (*batch_shape, 1) 

150 inner_betas = torch.cumsum(deltas, dim=-1)[..., :-1] 

151 ones = torch.ones(*deltas.shape[:-1], 1, device=deltas.device, dtype=deltas.dtype) 

152 betas = torch.cat([zeros, inner_betas, ones], dim=-1) 

153 

154 return deltas, betas, log_deltas 

155 

156 def _map_to_t_space(self, value: torch.Tensor) -> torch.Tensor: 

157 r"""Map values from the original $X$ space to the $T$ space $[0, 1]$ using the bounds.""" 

158 return torch.clamp((value - self.bound_low) / self.support_range, 0, 1) 

159 

160 def _map_to_x_space(self, t: torch.Tensor) -> torch.Tensor: 

161 r"""Map values from the $T$ space $[0, 1]$ back to the original $X$ space using the bounds.""" 

162 return t * self.support_range + self.bound_low 

163 

164 @property 

165 def support(self) -> constraints.Constraint: 

166 """Support of this distribution.""" 

167 return constraints.interval(self.bound_low, self.bound_up) 

168 

169 @property 

170 def support_range(self) -> float: 

171 """Range of the support, i.e., upper bound - lower bound.""" 

172 return self.bound_up - self.bound_low 

173 

174 @property 

175 def arg_constraints(self) -> dict[str, constraints.Constraint]: 

176 """Constraints that should be satisfied by each argument of this distribution. None for this class.""" 

177 return {"logits": constraints.real} 

178 

179 @property 

180 def degree(self) -> int: 

181 r"""Get the degree $n$ of the Bernstein polynomial based on the number of logits. 

182 

183 For a Bernstein polynomial of degree $n$, there are $n + 1$ control points (betas) and $n$ weights (deltas). 

184 """ 

185 return self.logits.shape[-1] 

186 

187 @property 

188 def mean(self) -> torch.Tensor: 

189 r"""Compute mean of the distribution, i.e., the weighted average of the control points. 

190 

191 We transform the random variable $X$ to $T$ in [0, 1] by scaling and shifting according to the bounds. 

192 Then, the mean of $T$ can be computed as 

193 

194 $$ E[T] = \sum_{i=0}^{n-1} \Delta_i \frac{i+1}{n+1} $$ 

195 

196 where $\Delta_i$ is the weight of the $i$-th control point, and $n$ is the degree of the Bernstein polynomial. 

197 We can then get the mean by rescaling $E[T]$ back to the original support: 

198 

199 $$ E[X] = (U - L) E[T] + L $$ 

200 

201 where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively. 

202 

203 Note: 

204 This method uses the exact Beta mixture formula. 

205 

206 Returns: 

207 Tensor of shape (*batch_shape,). 

208 """ 

209 i = torch.arange(self.degree, device=self._deltas.device, dtype=self._deltas.dtype) # shape: (degree,) 

210 e_t = torch.sum(self._deltas * (i + 1) / (self.degree + 1), dim=-1) 

211 

212 return self._map_to_x_space(e_t) 

213 

214 @property 

215 def variance(self) -> torch.Tensor: 

216 r"""Compute variance of the distribution. 

217 

218 We transform the random variable $X$ to $T$ in [0, 1] by scaling and shifting according to the bounds. 

219 Then, the variance of $T$ can be computed as 

220 

221 $$ Var[T] = E[T^2] - (E[T])^2 $$ 

222 

223 with 

224 

225 $$ E[T^2] = \sum_{i=0}^{n-1} \Delta_i \frac{(i+1)(i+2)}{(n+1)(n+2)} $$ 

226 

227 where $\Delta_i$ is the weight of the $i$-th control point, and $n$ is the degree of the Bernstein polynomial. 

228 We can then get the variance by rescaling $Var[T]$ back to the original support: 

229 

230 $$ Var[X] = (U - L)^2 Var[T] $$ 

231 

232 Note: 

233 This method uses the exact Beta mixture formula. 

234 

235 Returns: 

236 Tensor of shape (*batch_shape,). 

237 """ 

238 i = torch.arange(self.degree, device=self._deltas.device, dtype=self._deltas.dtype) # shape: (degree,) 

239 e_t = torch.sum(self._deltas * (i + 1) / (self.degree + 1), dim=-1) 

240 e_t2 = torch.sum(self._deltas * ((i + 1) * (i + 2)) / ((self.degree + 1) * (self.degree + 2)), dim=-1) 

241 var_t = e_t2 - e_t**2 

242 

243 return self.support_range**2 * var_t 

244 

245 def _eval_bezier_curve( 

246 self, 

247 t: torch.Tensor, 

248 weights: torch.Tensor, 

249 binom_coeffs: torch.Tensor, 

250 ) -> torch.Tensor: 

251 r"""Evaluates a Bezier curve (a Bernstein polynomial) in the $T \in [0, 1]$ space. 

252 

253 This method computes the weighted sum of Bernstein basis polynomials. Let $d$ be the degree of the polynomial 

254 being evaluated (either $n$ or $n+1$). Each basis polynomial is defined as: 

255 

256 $$ B_{i, d}(t) = \binom{d}{i} t^i (1-t)^{d-i} $$ 

257 

258 The polynomial's value $p(t)$ is computed as: 

259 

260 $$ p(t) = \sum_{i=0}^{d} w_i B_{i, d}(t) $$ 

261 

262 where $w_i$ are the weights (either betas for the CDF or deltas for the PDF). 

263 

264 Args: 

265 t: Normalized input values in [0, 1]. 

266 Expected shape: (*sample_shape, *batch_shape). 

267 weights: The coefficients for the basis polynomials. 

268 Expected shape: (*batch_shape, d + 1). 

269 binom_coeffs: Precomputed binomial coefficients corresponding to the polynomial's degree. 

270 Expected shape: (d + 1,). 

271 

272 Returns: 

273 The evaluated polynomial values. 

274 Output shape: (*sample_shape, *batch_shape) 

275 """ 

276 # Get n which can be != self.degree as we use this method for both CDF and PDF which have different degrees. 

277 nun_coeffs = binom_coeffs.shape[0] 

278 

279 # Create a tensor of indices matching the number of basis polynomials. 

280 i = torch.arange(nun_coeffs, device=t.device, dtype=t.dtype) 

281 

282 # Add an empty dimension to t for broadcasting, resulting in shape: (*sample_shape, *batch_shape, 1). 

283 t_expanded = t.unsqueeze(-1) 

284 

285 # Compute the entire basis in one shot. 

286 # PyTorch broadcasts the shapes to shape (*sample_shape, *batch_shape, degree). 

287 basis = binom_coeffs * (t_expanded**i) * ((1 - t_expanded) ** (nun_coeffs - 1 - i)) 

288 

289 # Multiply by weights and sum across the final dimension, resulting in shape (*sample_shape, *batch_shape). 

290 return torch.sum(weights * basis, dim=-1) 

291 

292 def cdf(self, value: torch.Tensor) -> torch.Tensor: 

293 """Compute cumulative distribution function at given values. 

294 

295 Args: 

296 value: Values at which to compute the CDF. Expected shape: (*sample_shape, *batch_shape). 

297 

298 Returns: 

299 CDF values in [0, 1] corresponding to the input values. Output shape: same as `value` argument. 

300 """ 

301 x = value.to(device=self.logits.device, dtype=self.logits.dtype) 

302 

303 # Map X in [bound_low, bound_up] to T in [0, 1]. 

304 t = self._map_to_t_space(x) 

305 

306 # Construct and evaluate the Bezier curve in T space. 

307 return self._eval_bezier_curve(t, weights=self._betas, binom_coeffs=self._binom_coeffs_cdf) 

308 

309 def prob(self, value: torch.Tensor) -> torch.Tensor: 

310 """Compute probability density at given values. 

311 

312 Args: 

313 value: Values at which to compute the PDF. Expected shape: (*sample_shape, *batch_shape). 

314 

315 Returns: 

316 PDF values corresponding to the input values. Output shape: same as `value` argument. 

317 """ 

318 x = value.to(device=self.logits.device, dtype=self.logits.dtype) 

319 

320 # Map X in [bound_low, bound_up] to T in [0, 1]. 

321 t = self._map_to_t_space(x) 

322 

323 # Construct and evaluate the Bezier curve in T space. 

324 val = self._eval_bezier_curve(t, weights=self._deltas, binom_coeffs=self._binom_coeffs_pdf) 

325 

326 # Apply the chain rule: dt/dx = 1 / (U - L). 

327 pdf_val = val * self.degree / self.support_range 

328 

329 # Mask out values outside [bound_low, bound_up]. 

330 mask = (value >= self.bound_low) & (value <= self.bound_up) 

331 return torch.where(mask, pdf_val, torch.zeros_like(pdf_val)) 

332 

333 def log_prob(self, value: torch.Tensor) -> torch.Tensor: 

334 r"""Compute the log-probability density at given values, entirely in log-space for numerical stability. 

335 

336 Uses the identity 

337 

338 $$ 

339 \log p(x) = \log \frac{n}{U - L} 

340 + \text{logsumexp}_i\!\Big(\log \Delta_i + \log \binom{n-1}{i} 

341 + i \log t + (n-1-i) \log(1-t)\Big) 

342 $$ 

343 

344 where $t = (x - L) / (U - L)$ is the normalized input. Every term is computed in log-space, 

345 avoiding the numerically problematic ``log(polynomial + eps)`` path. 

346 

347 Args: 

348 value: Values at which to compute the log-PDF. Expected shape: (*sample_shape, *batch_shape). 

349 

350 Returns: 

351 Log-PDF values corresponding to the input values. Output shape: same as `value` argument. 

352 """ 

353 x = value.to(device=self.logits.device, dtype=self.logits.dtype) 

354 t = self._map_to_t_space(x) 

355 

356 eps = torch.finfo(t.dtype).eps 

357 n = self.degree 

358 

359 # Clamp t away from exact 0/1 to avoid log(0). 

360 t_safe = t.clamp(min=eps, max=1 - eps) 

361 

362 # Indices for the Bernstein basis: i = 0, ..., n-1. 

363 i = torch.arange(n, device=t.device, dtype=t.dtype) 

364 

365 # Expand t for broadcasting: (*sample_shape, *batch_shape, 1). 

366 log_t = t_safe.unsqueeze(-1) # will broadcast with i 

367 

368 # Log of each Bernstein basis term: log(binom) + i*log(t) + (n-1-i)*log(1-t). 

369 log_basis = self._log_binom_coeffs_pdf + i * log_t.log() + (n - 1 - i) * (1 - log_t).log() 

370 

371 # Log of each weighted term: log(delta_i) + log(basis_i). 

372 # _log_deltas shape: (*batch_shape, n), log_basis shape: (*sample_shape, *batch_shape, n). 

373 log_terms = self._log_deltas + log_basis 

374 

375 # Sum via logsumexp over the last dimension. 

376 log_bezier = torch.logsumexp(log_terms, dim=-1) 

377 

378 # Apply the chain rule: log(n / (U - L)) + log(bezier). 

379 log_pdf = math.log(n / self.support_range) + log_bezier 

380 

381 # Mask values outside the support. 

382 mask = (value >= self.bound_low) & (value <= self.bound_up) 

383 return torch.where(mask, log_pdf, torch.full_like(log_pdf, -math.inf)) 

384 

385 def entropy(self, num_quadrature_points: int = 251) -> torch.Tensor: 

386 r"""Compute differential entropy of the distribution via numerical quadrature. 

387 

388 $$ H(X) = -\int_{L}^{U} p(x) \log p(x) \, dx $$ 

389 

390 where $L$ and $U$ are the lower and upper bounds of the distribution support, respectively. 

391 

392 Args: 

393 num_quadrature_points: Number of points for the trapezoidal rule approximation. 

394 

395 Returns: 

396 Tensor of shape (*batch_shape,). 

397 """ 

398 # Create quadrature points over the support. 

399 x = torch.linspace( 

400 self.bound_low, self.bound_up, num_quadrature_points, device=self.logits.device, dtype=self.logits.dtype 

401 ) 

402 

403 # For batched distributions, expand quadrature points to shape (num_quadrature_points, *batch_shape) 

404 # so prob/log_prob receive values with explicit batch dimensions. 

405 x_eval = x.reshape(num_quadrature_points, *([1] * len(self.batch_shape))) 

406 x_eval = x_eval.expand(num_quadrature_points, *self.batch_shape) 

407 

408 # Evaluate PDF at quadrature points. 

409 pdf_val = self.prob(x_eval) # shape: (num_quadrature_points, *batch_shape) 

410 

411 # Compute the integrand: -p(x) * log(p(x)), with epsilon for stability. 

412 eps = torch.finfo(pdf_val.dtype).eps 

413 log_pdf = torch.log(pdf_val + 2 * eps) 

414 integrand = -pdf_val * log_pdf # shape: (num_quadrature_points, *batch_shape) 

415 

416 # Integrate using the trapezoidal rule. 

417 return torch.trapezoid(integrand, x, dim=0) 

418 

419 def icdf( 

420 self, 

421 value: torch.Tensor, 

422 num_iter: int = 8, 

423 use_newton: bool = True, 

424 newton_damping: float = 0.9, 

425 convergence_eps_factor: float = 20.0, 

426 ) -> torch.Tensor: 

427 r"""Compute the inverse CDF, i.e., the quantile function, at the given values. 

428 

429 Two solvers are available for inverting $ F(x) - q = 0 $: 

430 

431 **Newton's method** uses the PDF as the exact derivative of the CDF and iterates 

432 

433 $$ x_{k+1} = x_k - \alpha \frac{F(x_k) - q}{f(x_k)} $$ 

434 

435 where $F(x)$ is the CDF, $f(x)$ is the PDF, $q$ is the target quantile in [0, 1], 

436 and $\alpha \in (0, 1]$ is a damping factor that shrinks each Newton step to improve robustness. 

437 A bracket $[L_k, U_k]$ is maintained alongside: whenever $F(x_k) < q$ the lower bound tightens, 

438 otherwise the upper bound tightens. If the Newton step would leave the bracket, a bisection 

439 step is used instead, guaranteeing monotonic bracket contraction and preventing oscillation. 

440 The loop exits early once all elements satisfy $|F(x) - q| < \epsilon$. 

441 

442 **Bisection** halves the search interval each iteration, gaining ~1 bit of precision per step. 

443 

444 Args: 

445 value: Values in [0, 1] at which to compute the inverse CDF. Expected shape: (*sample_shape, *batch_shape). 

446 num_iter: Maximum number of solver iterations. Newton typically converges undamped in ~6-7 iterations; 

447 bisection needs ~15-20 for full float32 precision. 

448 use_newton: If True, use Newton's method. If False, use pure bisection. 

449 newton_damping: Damping factor in (0, 1] applied to the Newton step. A value of 1.0 gives the 

450 full Newton step (quadratic convergence), while smaller values improve robustness 

451 at the cost of slower convergence. 

452 convergence_eps_factor: The factor multiplied by machine epsilon to determine the convergence criterion. 

453 

454 Returns: 

455 Quantiles in [bound_low, bound_up] corresponding to the input CDF values. 

456 Output shape: same as `value` argument. 

457 """ 

458 q = value.to(device=self.logits.device, dtype=self.logits.dtype) 

459 eps = torch.finfo(q.dtype).eps 

460 

461 # Ensure target probability value is strictly in [0, 1]. 

462 q = torch.clamp(q, 0.0, 1.0) 

463 

464 # Start from the midpoint of the support. 

465 mid = torch.full_like(q, (self.bound_low + self.bound_up) / 2) 

466 low = torch.full_like(q, self.bound_low) 

467 high = torch.full_like(q, self.bound_up) 

468 

469 for _ in range(num_iter): 

470 cdf_mid = self.cdf(mid) 

471 

472 # Early stop when all elements have converged. 

473 abs_deviation = (cdf_mid - q).abs().max() 

474 if abs_deviation < convergence_eps_factor * eps: 

475 break 

476 

477 # Tighten the bracket based on CDF evaluation. 

478 low = torch.where(cdf_mid < q, mid, low) 

479 high = torch.where(cdf_mid >= q, mid, high) 

480 bisect_mid = (low + high) / 2 

481 

482 if use_newton: 

483 # Newton step: x_{k+1} = x_k - (F(x_k) - q) / f(x_k). 

484 pdf_mid = self.prob(mid) 

485 newton_mid = mid - newton_damping * (cdf_mid - q) / pdf_mid.clamp_min(2 * eps) 

486 

487 # Use Newton step if it stays within the bracket, otherwise fall back to bisection. 

488 in_bracket = (newton_mid >= low) & (newton_mid <= high) 

489 mid = torch.where(in_bracket, input=newton_mid, other=bisect_mid) 

490 

491 else: 

492 mid = bisect_mid 

493 

494 return mid 

495 

496 def rsample(self, sample_shape: torch.Size | list[int] | tuple[int, ...] = _size) -> torch.Tensor: 

497 """Draws reparameterized samples from the distribution, and allows gradients to flow backawards. 

498 

499 Args: 

500 sample_shape: Desired shape of the samples to be drawn. Default is empty, which means one sample per batch element. 

501 

502 Returns: 

503 Samples drawn from the distribution, with shape (*sample_shape, *batch_shape). 

504 """ 

505 # Determine the final shape of the output tensor. 

506 shape = self._extended_shape(sample_shape) 

507 

508 # Sample uniform noise, u ~ U(0, 1). 

509 u = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) 

510 

511 # Find the root (the sample x) without tracking gradients for the loop. 

512 with torch.no_grad(): 

513 x_root = self.icdf(u) 

514 

515 # Apply the implicit differentiation trick, i.e., evaluate CDF to connect the parameters to the 

516 # computational graph. 

517 cdf_val = self.cdf(x_root) 

518 

519 # Evaluate PDF and detach it to act as the constant denominator. 

520 pdf_val = self.prob(x_root).detach() 

521 

522 # Clamp PDF to avoid division by zero near the boundaries where slope is 0. This limits the gradients. 

523 eps = torch.finfo(pdf_val.dtype).eps 

524 pdf_val = pdf_val.clamp_min(2 * eps) 

525 

526 # Attach the exact reparameterized gradient. 

527 x = x_root + (u - cdf_val) / pdf_val 

528 

529 # Clamp to the support to prevent the implicit-differentiation correction from pushing samples 

530 # slightly past the domain boundaries when the CDF is very flat near the bounds. 

531 return x.clamp(min=self.bound_low, max=self.bound_up)