plotting

categorical

draw_categorical(plot_type: str, ax: Axes, data: Union[list, ndarray, Tensor, DataFrame], x_label: Optional[Union[str, Sequence[str]]], y_label: Optional[str], vline_level: Optional[float] = None, vline_label: str = 'approx. solved', palette=None, title: Optional[str] = None, show_legend: bool = True, legend_kwargs: Optional[dict] = None, plot_kwargs: Optional[dict] = None) -> Figure

Create a box or violin plot for a list of data arrays or a pandas DataFrame. The plot is neither shown nor saved.

To move the 4th color of the palette to the 2nd position, use

palette.insert(1, palette.pop(3))
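For illustration, here is what that one-liner does on a plain list of color names (the names are placeholders, not the default palette). `pop(3)` removes the 4th entry and `insert(1, ...)` puts it back at index 1, shifting the following entries one place to the right:

```python
# Placeholder palette; a seaborn palette is a list subclass, so the same calls apply
palette = ["blue", "orange", "green", "red", "purple"]

# Move the 4th color (index 3) to the 2nd position (index 1)
palette.insert(1, palette.pop(3))
print(palette)  # ['blue', 'red', 'orange', 'green', 'purple']
```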

Note

If you want to have a tight layout, it is best to pass axes of a figure with tight_layout=True or constrained_layout=True.

Parameters:
  • plot_type – type of categorical plot, pass ‘box’ or ‘violin’

  • ax – axis of the figure to plot on

  • data – list of data sets to plot as separate boxes

  • x_label – labels for the categories on the x-axis, if data is not given as a DataFrame

  • y_label – label for the y-axis, pass None to set no label

  • vline_level – if not None, add a vertical line at the given level; by default (None) no line is drawn

  • vline_label – label for the vertical line

  • palette – seaborn color palette, pass None to use the default palette

  • show_legend – if True, the legend is shown; setting it to False is useful when handling multiple subplots

  • title – title displayed above the figure, set to None to suppress the title

  • legend_kwargs – keyword arguments forwarded to pyplot’s legend() function, e.g. loc=’best’

  • plot_kwargs – keyword arguments forwarded to seaborn’s boxplot() or violinplot() function

Returns:

handle to the resulting figure

curve

draw_curve(plot_type: str, ax: Axes, data: DataFrame, x_grid: Union[list, ndarray, Tensor], x_label: Optional[Union[str, Sequence[str]]] = None, y_label: Optional[str] = None, curve_label: Optional[str] = None, area_label: Optional[str] = '', vline_level: Optional[float] = None, vline_label: str = 'approx. solved', title: Optional[str] = None, show_legend: bool = True, plot_kwargs: Optional[dict] = None, legend_kwargs: Optional[dict] = None) -> Figure

Create a 1-dim plot, i.e. a curve with a (transparent) area around it, from a pandas DataFrame holding precomputed statistics such as mean, std, min, and max. The plot is neither shown nor saved.

Note

If you want to have a tight layout, it is best to pass axes of a figure with tight_layout=True or constrained_layout=True.

To move the 4th color of the palette to the 2nd position, use

palette.insert(1, palette.pop(3))

Parameters:
  • plot_type – type of 1-dim plot: mean_std, min_mean_max, or ci_on_mean

  • ax – axis of the figure to plot on

  • data – pandas DataFrame containing the columns mean, std, min, and max depending on the plot_type

  • x_grid – values to plot the data over, e.g. time

  • x_label – labels for the categories on the x-axis, if data is not given as a DataFrame

  • y_label – label for the y-axis, pass None to set no label

  • curve_label – label of the (1-dim) curve, pass None for no label

  • area_label – label of the (transparent) area, pass None for no label and “” for the default label

  • vline_level – if not None, add a vertical line at the given level; by default (None) no line is drawn

  • vline_label – label for the vertical line

  • show_legend – if True, the legend is shown; setting it to False is useful when handling multiple subplots

  • title – title displayed above the figure, set to None to suppress the title

  • plot_kwargs – keyword arguments forwarded to the plotting functions

  • legend_kwargs – keyword arguments forwarded to pyplot’s legend() function, e.g. loc=’best’

Returns:

handle to the resulting figure
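The data argument is expected to contain the statistics already. A minimal sketch of preparing them from several runs with the standard library, assuming the column names mean, std, min, and max listed above (which of them are needed depends on plot_type). The runs are made-up numbers; the resulting dict would typically be wrapped in pandas.DataFrame before calling draw_curve:

```python
import statistics

# Hypothetical data: 3 runs (rows) evaluated at the same 4 points of x_grid (columns)
runs = [
    [1.0, 2.0, 3.0, 4.0],
    [2.0, 3.0, 4.0, 5.0],
    [3.0, 4.0, 5.0, 6.0],
]
x_grid = [0.0, 0.1, 0.2, 0.3]

# Group the values of all runs by grid point, one tuple per grid point
columns = list(zip(*runs))
assert len(columns) == len(x_grid)  # one statistic per grid point

stats = {
    "mean": [statistics.fmean(c) for c in columns],
    # Population std (ddof=0), like numpy.std's default
    "std": [statistics.pstdev(c) for c in columns],
    "min": [min(c) for c in columns],
    "max": [max(c) for c in columns],
}
print(stats["mean"])  # [2.0, 3.0, 4.0, 5.0]
```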

draw_curve_from_data(plot_type: str, ax: Axes, data: Union[list, ndarray, Tensor, DataFrame], x_grid: Union[list, ndarray, Tensor], ax_calc: int, x_label: Optional[Union[str, Sequence[str]]] = None, y_label: Optional[str] = None, curve_label: Optional[str] = None, area_label: Optional[str] = '', vline_level: Optional[float] = None, vline_label: str = 'approx. solved', title: Optional[str] = None, show_legend: bool = True, cmp_kwargs: Optional[dict] = None, plot_kwargs: Optional[dict] = None, legend_kwargs: Optional[dict] = None) -> Figure

Compute statistics of the data (e.g. mean and standard deviation) along the axis ax_calc and draw the resulting 1-dim curve together with a (transparent) area. The plot is neither shown nor saved.

Note

If you want to have a tight layout, it is best to pass axes of a figure with tight_layout=True or constrained_layout=True.

To move the 4th color of the palette to the 2nd position, use

palette.insert(1, palette.pop(3))

Parameters:
  • plot_type – type of 1-dim plot: mean_std, min_mean_max, or ci_on_mean

  • ax – axis of the figure to plot on

  • data – data to plot, e.g. a time series

  • x_grid – values to plot the data over, e.g. time

  • ax_calc – axis of the data array to calculate the mean, min and max, or std over

  • x_label – labels for the categories on the x-axis, if data is not given as a DataFrame

  • y_label – label for the y-axis, pass None to set no label

  • curve_label – label of the (1-dim) curve, pass None for no label

  • area_label – label of the (transparent) area, pass None for no label and “” for the default label

  • vline_level – if not None, add a vertical line at the given level; by default (None) no line is drawn

  • vline_label – label for the vertical line

  • show_legend – if True, the legend is shown; setting it to False is useful when handling multiple subplots

  • title – title displayed above the figure, set to None to suppress the title

  • cmp_kwargs – keyword arguments forwarded to functions computing the statistics of interest

  • plot_kwargs – keyword arguments forwarded to the plotting functions

  • legend_kwargs – keyword arguments forwarded to pyplot’s legend() function, e.g. loc=’best’

Returns:

handle to the resulting figure
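ax_calc selects the axis that is reduced. As a sketch of the idea with nested lists standing in for a 2-dim array of shape (num_runs, num_steps): reducing over axis 0 yields one value per time step, which is what gets plotted over x_grid, while axis 1 yields one value per run. The helper below is hypothetical and only computes the mean:

```python
import statistics
from typing import List


def mean_over_axis(data: List[List[float]], ax_calc: int) -> List[float]:
    """Hypothetical helper: mean of a 2-dim nested list along ax_calc (0 or 1)."""
    if ax_calc == 0:
        # Reduce over rows -> one value per column (e.g. per time step)
        return [statistics.fmean(col) for col in zip(*data)]
    if ax_calc == 1:
        # Reduce over columns -> one value per row (e.g. per run)
        return [statistics.fmean(row) for row in data]
    raise ValueError("ax_calc must be 0 or 1 for 2-dim data")


data = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]  # 2 runs, 3 time steps
print(mean_over_axis(data, ax_calc=0))  # [2.0, 3.0, 4.0]
print(mean_over_axis(data, ax_calc=1))  # [2.0, 4.0]
```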

draw_dts(dts_policy: ndarray, dts_step: ndarray, dts_remainder: ndarray, y_top_lim: Optional[float] = None)

Create a figure and draw the time intervals \(\Delta_t\) of various parts of one time step.

Parameters:
  • dts_policy – time it took to compute the policy’s action

  • dts_step – time it took to perform the

  • dts_remainder – time it took to execute all remaining commands (e.g. storing the data)

  • y_top_lim – upper bound for the y-axis in ms, no limit by default
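The three arrays presumably hold one duration per time step. One way to record them, sketched with the standard library, is to take time.perf_counter() timestamps around the policy call, the environment step, and the rest of the loop body. compute_action, env_step, and store are hypothetical stand-ins. The durations below are in seconds; since y_top_lim is given in ms, check which unit draw_dts expects before passing them:

```python
import time


def compute_action(obs):  # hypothetical policy stand-in
    return -0.5 * obs


def env_step(act):  # hypothetical environment stand-in
    return act + 1.0


def store(buffer, obs, act):  # hypothetical bookkeeping stand-in
    buffer.append((obs, act))


dts_policy, dts_step, dts_remainder = [], [], []
buffer, obs = [], 1.0
for _ in range(100):
    t0 = time.perf_counter()
    act = compute_action(obs)
    t1 = time.perf_counter()
    next_obs = env_step(act)
    t2 = time.perf_counter()
    store(buffer, obs, act)
    obs = next_obs
    t3 = time.perf_counter()
    # Durations in seconds; convert to numpy arrays before calling draw_dts
    dts_policy.append(t1 - t0)
    dts_step.append(t2 - t1)
    dts_remainder.append(t3 - t2)
```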

distribution

gaussian_process

heatmap

draw_heatmap(data: DataFrame, ax_hm: Optional[Axes] = None, cmap: Optional[Colormap] = None, norm: Optional[Normalize] = Normalize(), annotate: bool = True, annotation_valfmt: Optional[str] = '{x:.0f}', add_cbar: bool = True, separate_cbar: bool = False, ax_cb: Optional[Axes] = None, cbar_label: Optional[str] = None, cbar_orientation: Optional[str] = 'vertical', use_index_labels: bool = False, x_label: Optional[str] = None, y_label: Optional[str] = None, fig_canvas_title: Optional[str] = None, fig_size: Optional[tuple] = (8, 6), tick_label_prec: Optional[int] = 3, xtick_label_prec: Optional[int] = None, ytick_label_prec: Optional[int] = None, num_major_ticks_hm: Optional[int] = None, num_major_ticks_cb: Optional[int] = None) -> Tuple[Figure, Figure]

Plot a 2D heat map from a 2D pandas.DataFrame using pyplot. The data frame should have exactly one column index level and one row index level. These will automatically become the axis ticks. It is assumed that the data is equally spaced.

Note

If you want to have a tight layout, it is best to pass axes of a figure with tight_layout=True or constrained_layout=True.

Parameters:
  • data – 2D pandas DataFrame

  • ax_hm – axis to draw the heat map onto, if None a new figure is created

  • cmap – colormap passed to imshow()

  • norm – colormap normalizer passed to imshow()

  • annotate – select if the heat map should be annotated

  • annotation_valfmt – format of the annotations inside the heat map, irrelevant if annotate = False

  • add_cbar – if True, add a color bar in the same figure

  • separate_cbar – if True, the color bar is added in a separate figure

  • ax_cb – axis to draw the color bar onto, if None a new figure is created

  • cbar_label – label for the color bar

  • cbar_orientation – orientation of the color bar

  • use_index_labels – if True, the index names of the pandas DataFrame are used as labels for the x- and y-axis. This can be overridden by x_label and y_label

  • x_label – label for the x axis

  • y_label – label for the y axis

  • fig_canvas_title – window title for the heat map plot, no title by default

  • fig_size – width and height of the figure in inches

  • tick_label_prec – floating point precision of the x- and y-axis labels. This can be overridden by xtick_label_prec and ytick_label_prec

  • xtick_label_prec – floating point precision of the x-axis labels, set None for default behavior

  • ytick_label_prec – floating point precision of the y-axis labels, set None for default behavior

  • num_major_ticks_hm – number of major axis ticks for the heat map, set None for default behavior

  • num_major_ticks_cb – number of major axis ticks for the color bar, set None for default behavior

Returns:

handles to the heat map and the color bar figures (None if not existent)
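Because the index values become the tick labels and the data is assumed to be equally spaced, it can be worth checking the index before plotting. A stdlib sketch with hypothetical helper names that are not part of the module; interpreting the tick label precision as the number of decimals is an assumption:

```python
import math
from typing import List, Sequence


def is_equally_spaced(values: Sequence[float], rel_tol: float = 1e-9) -> bool:
    """Hypothetical check: True if all consecutive differences are (nearly) equal."""
    diffs = [b - a for a, b in zip(values, values[1:])]
    return all(math.isclose(d, diffs[0], rel_tol=rel_tol, abs_tol=1e-12) for d in diffs)


def format_ticks(values: Sequence[float], prec: int = 3) -> List[str]:
    """Hypothetical formatter: fixed-point labels with `prec` decimals."""
    return [f"{v:.{prec}f}" for v in values]


index = [0.0, 0.25, 0.5, 0.75, 1.0]  # e.g. the row index of the DataFrame
print(is_equally_spaced(index))  # True
print(format_ticks(index, prec=3))  # ['0.000', '0.250', '0.500', '0.750', '1.000']
```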

live_update

class LiveFigureManager(file_path: str, data_loader: Callable[[str], Any], args, update_interval: int = 3)

Bases: object

Manages multiple matplotlib figures and refreshes them when the input file changes. It also ensures that if you close a figure, it does not reappear on the next update. If all figures are closed, the update loop is stopped.

Constructor

Parameters:
  • file_path – name of file to load updates from

  • data_loader – called to load the file contents into some internal representation like a pandas DataFrame

  • args – parsed command line arguments

  • update_interval – time to wait between figure updates [s]
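The refresh mechanism described above comes down to polling the input file every update_interval seconds and reloading it when it has changed. A stdlib sketch of such a loop, assuming change detection via the file's modification time (how LiveFigureManager actually detects changes is not stated here). redraw is a hypothetical callback, and the per-figure bookkeeping is omitted:

```python
import os
import time
from typing import Any, Callable


def poll_file(
    file_path: str,
    data_loader: Callable[[str], Any],
    redraw: Callable[[Any], bool],
    update_interval: float = 3.0,
) -> None:
    """Reload file_path whenever its modification time changes and pass the data to redraw.

    redraw is a hypothetical callback; returning False stops the loop,
    similar to closing the last open figure.
    """
    last_mtime = None
    while True:
        mtime = os.path.getmtime(file_path)
        if mtime != last_mtime:
            last_mtime = mtime
            if not redraw(data_loader(file_path)):
                break
        time.sleep(update_interval)
```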

figure(title: Optional[str] = None)

Decorator to define a figure update function. Every marked function will be called when the file changes to visualize the updated data.

Usage:

@lfm.figure('A figure')
def a_figure(fig, data, args):
    ax = fig.add_subplot(111)
    ax.plot(data[...])

Parameters:

title – figure title

Returns:

decorator for the plotting function
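A decorator that takes an argument, like figure(title), is a decorator factory: calling it returns the actual decorator, which records the plotting function so it can be called on every update. A minimal stdlib sketch of that registration pattern with a hypothetical FigureRegistry class (the real class additionally creates and manages the matplotlib figures):

```python
from typing import Any, Callable, List, Optional, Tuple

PlotFcn = Callable[[Any, Any, Any], None]


class FigureRegistry:
    """Hypothetical stand-in showing how figure(title) can register update functions."""

    def __init__(self) -> None:
        self.plot_fcns: List[Tuple[Optional[str], PlotFcn]] = []

    def figure(self, title: Optional[str] = None) -> Callable[[PlotFcn], PlotFcn]:
        def decorator(fcn: PlotFcn) -> PlotFcn:
            # Remember the function together with its title, return it unchanged
            self.plot_fcns.append((title, fcn))
            return fcn

        return decorator

    def update(self, data: Any, args: Any) -> None:
        # Call every registered function; None stands in for the figure handle
        for _, fcn in self.plot_fcns:
            fcn(None, data, args)


lfm = FigureRegistry()
calls = []


@lfm.figure("A figure")
def a_figure(fig, data, args):
    calls.append(data)


lfm.update(data=[1, 2, 3], args=None)
print(calls)  # [[1, 2, 3]]
```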

spin()

Run the plot update loop.

policy_parameters

draw_policy_params(policy: Policy, env_spec: EnvSpec, cmap_name: str = 'RdBu', ax_hm: Optional[Axes] = None, annotate: bool = True, annotation_valfmt: str = '{x:.2f}', cbar_label: str = '', x_label: Optional[str] = None, y_label: Optional[str] = None) -> Figure

Plot the weights and biases as images, and a color bar.

Note

If you want to have a tight layout, it is best to pass axes of a figure with tight_layout=True or constrained_layout=True.

Parameters:
  • policy – policy to visualize

  • env_spec – environment specification

  • cmap_name – name of the color map, e.g. ‘inferno’, ‘RdBu’, or ‘viridis’

  • ax_hm – axis to draw the heat map onto, if equal to None a new figure is opened

  • annotate – select if the heat map should be annotated

  • annotation_valfmt – format of the annotations inside the heat map, irrelevant if annotate = False

  • cbar_label – label for the color bar

  • x_label – label for the x axis

  • y_label – label for the y axis

Returns:

handles to figures

rollout_based

plot_actions(ro: StepSequence, env: Env)

Plot all action trajectories of the given rollout.

Parameters:
  • ro – input rollout

  • env – environment (used for getting the clipped action values)

plot_features(ro: StepSequence, policy: Policy)

Plot all features given the policy and the observation trajectories.

Parameters:
  • policy – linear policy used during the rollout

  • ro – input rollout

plot_mean_std_across_rollouts(rollouts: Sequence[StepSequence], idcs_obs: Optional[Sequence[int]] = None, idcs_act: Optional[Sequence[int]] = None, show_applied_actions: bool = True)

Plot the mean and standard deviation across a selection of rollouts.

Parameters:
  • rollouts – list of rollouts, they can be of unequal length but are assumed to be from the same type of env

  • idcs_obs – indices of the observations to process and plot, pass None to select all

  • idcs_act – indices of the actions to process and plot, pass None to select all

  • show_applied_actions – if True, show the actions applied to the environment instead of the commanded ones
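Since the rollouts may differ in length, the statistics cannot be taken over a plain rectangular array. One possible approach, sketched with the standard library for a single observation dimension, is to use only the rollouts that are still running at each time step. Whether plot_mean_std_across_rollouts does exactly this (or, for example, truncates to the shortest rollout) is an assumption:

```python
import statistics
from typing import List, Sequence, Tuple


def mean_std_unequal(trajs: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-time-step mean and population std over trajectories of unequal length."""
    max_len = max(len(t) for t in trajs)
    means, stds = [], []
    for k in range(max_len):
        # Only the trajectories that are long enough contribute at step k
        vals = [t[k] for t in trajs if len(t) > k]
        means.append(statistics.fmean(vals))
        stds.append(statistics.pstdev(vals))
    return means, stds


means, stds = mean_std_unequal([[0.0, 1.0, 2.0], [2.0, 3.0]])
print(means)  # [1.0, 2.0, 2.0]
print(stds)  # [1.0, 1.0, 0.0]
```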

plot_observations(ro: StepSequence, idcs_sel: Optional[Sequence[int]] = None)

Plot all observation trajectories of the given rollout.

Parameters:
  • ro – input rollout

  • idcs_sel – indices of the observations to plot, if None plot all

plot_observations_actions_rewards(ro: StepSequence)

Plot all observation, action, and reward trajectories of the given rollout.

Parameters:

ro – input rollout

plot_potentials(ro: StepSequence, layout: str = 'joint')

Plot the trajectories specific to a potential-based policy.

Parameters:
  • ro – input rollout

  • layout – group jointly (default), or create a separate sub-figure for each plot

plot_rewards(ro: StepSequence)

Plot the reward trajectories of the given rollout.

Parameters:

ro – input rollout

plot_rollouts_segment_wise(plot_type: str, segments_ground_truth: List[List[StepSequence]], segments_multiple_envs: List[List[List[StepSequence]]], segments_nominal: List[List[StepSequence]], use_rec_str: bool, idx_iter: int, idx_round: int, state_labels: Optional[Iterable[str]] = None, act_labels: Optional[Iterable[str]] = None, x_limits: Optional[Tuple[int]] = None, plot_act: bool = False, data_field: str = 'states', cmap_samples: Optional[Colormap] = None, save_dir: Optional[PathLike] = None, file_format: Iterable[str] = ('pdf', 'pgf', 'png')) -> List[Figure]

Plot the different rollouts in separate figures and the different state dimensions along the columns.

Parameters:
  • plot_type – type of plot, pass “samples” to plot the rollouts of the most likely domain parameters as individual lines, or pass “confidence” to plot the most likely one, and the mean \(\pm\) 1 std

  • segments_ground_truth – list of lists containing rollout segments from the ground truth environment

  • segments_multiple_envs – list of lists of lists containing rollout segments from different environment instances, e.g. samples from a posterior coming from NPDR

  • segments_nominal – list of lists containing rollout segments from the nominal environment

  • use_rec_str – True if pre-recorded actions have been used to generate the rollouts

  • idx_iter – selected iteration

  • idx_round – selected round

  • state_labels – y-axis labels for the state trajectories, no label by default

  • act_labels – y-axis labels for the action trajectories, no label by default

  • x_limits – tuple containing the lower and upper limits for the x-axis

  • plot_act – if True, also plot the actions

  • data_field – data field of the rollout, e.g. “states” or “observations”

  • cmap_samples – color map for the trajectories resulting from different domain parameter samples

  • save_dir – if not None, create a subfolder “plots” in save_dir and save the plots there

  • file_format – file formats in which to store the plots

Returns:

list of handles to the created figures

plot_states(ro: StepSequence, idcs_sel: Optional[Sequence[int]] = None)

Plot all state trajectories of the given rollout.

Parameters:
  • ro – input rollout

  • idcs_sel – indices of the states to plot, if None plot all

plot_statistic_across_rollouts(rollouts: Sequence[StepSequence], stat_fcn: callable, stat_fcn_kwargs=None, obs_idcs: Optional[Sequence[int]] = None, act_idcs: Optional[Sequence[int]] = None)

Plot one statistic of interest (e.g. mean) across a list of rollouts.

Parameters:
  • rollouts – list of rollouts, they can be of unequal length but are assumed to be from the same type of env

  • stat_fcn – function to calculate the statistic of interest (e.g. np.mean)

  • stat_fcn_kwargs – keyword arguments for the stat_fcn (e.g. {‘axis’: 0})

  • obs_idcs – indices of the observations to process and plot, pass None to select all

  • act_idcs – indices of the actions to process and plot, pass None to select all

surface

draw_surface(x: ndarray, y: ndarray, z_fcn: Union[Callable[[ndarray], ndarray], Module], x_label: str, y_label: str, z_label: str, data_format='numpy', fig: Optional[Figure] = None, title: Optional[str] = None, plot_kwargs: Optional[dict] = None) -> Figure

Render a 3-dim surface plot by providing a 1-dim array of x and y points and a function to calculate the z values.

Note

If you want to have a tight layout, it is best to pass a figure created with tight_layout=True or constrained_layout=True.

Parameters:
  • x – x-axis 1-dim grid for constructing the 2-dim mesh grid

  • y – y-axis 1-dim grid for constructing the 2-dim mesh grid

  • z_fcn – function that defines the surface, takes a 2-dim vector as input

  • x_label – label for the x-axis

  • y_label – label for the y-axis

  • z_label – label for the z-axis

  • data_format – data format, ‘numpy’ or ‘torch’

  • fig – handle to figure, pass None to create a new figure

  • title – title displayed above the figure, set to None to suppress the title

  • plot_kwargs – keyword arguments forwarded to the plot_surface() method of the 3D axes (mpl_toolkits.mplot3d.Axes3D)

Returns:

handle to figure
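To clarify how the inputs fit together: the two 1-dim grids span a 2-dim mesh, and z_fcn is evaluated at every mesh point, each given as a 2-dim vector (x, y). A stdlib sketch of that evaluation with a hypothetical paraboloid as z_fcn; the row/column order mirrors numpy.meshgrid's default indexing, while draw_surface itself works on numpy arrays or torch tensors (per data_format), which is not reproduced here:

```python
from typing import Callable, List, Sequence


def eval_on_mesh(
    x: Sequence[float], y: Sequence[float], z_fcn: Callable[[Sequence[float]], float]
) -> List[List[float]]:
    """Return z[i][j] = z_fcn((x[j], y[i])), i.e. rows follow y and columns follow x."""
    return [[z_fcn((xj, yi)) for xj in x] for yi in y]


def paraboloid(v: Sequence[float]) -> float:  # hypothetical example surface
    return v[0] ** 2 + v[1] ** 2


z = eval_on_mesh([-1.0, 0.0, 1.0], [0.0, 2.0], paraboloid)
print(z)  # [[1.0, 0.0, 1.0], [5.0, 4.0, 5.0]]
```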

utils

class AccNorm(vmin=None, vmax=None, clip=False)

Bases: Normalize

Accumulative normalizer which is useful to have one colormap consistent for multiple images. Adding new data will expand the limits.

Parameters

vmin, vmax : float or None

If vmin and/or vmax is not given, they are initialized from the minimum and maximum value, respectively, of the first input processed; i.e., __call__(A) calls autoscale_None(A).

clip : bool, default: False

If True values falling outside the range [vmin, vmax], are mapped to 0 or 1, whichever is closer, and masked values are set to 1. If False masked values remain masked.

Clipping silently defeats the purpose of setting the over, under, and masked colors in a colormap, so it is likely to lead to surprises; therefore the default is clip=False.

Notes

Returns 0 if vmin == vmax.

autoscale(A)

Set vmin, vmax to min, max of A.

autoscale_None(A)

If vmin or vmax are not set, use the min/max of A to set them.
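The accumulating behavior can be illustrated without matplotlib: each call widens [vmin, vmax] to cover the new data, and values are then scaled with the current limits. A stdlib sketch of the idea with a hypothetical class name; unlike matplotlib's Normalize it handles neither masked arrays nor clipping:

```python
from typing import List, Optional, Sequence


class AccumulatingNorm:
    """Hypothetical sketch: the limits only ever grow as new data is passed in."""

    def __init__(self, vmin: Optional[float] = None, vmax: Optional[float] = None) -> None:
        self.vmin = vmin
        self.vmax = vmax

    def __call__(self, values: Sequence[float]) -> List[float]:
        # Expand the limits to include the new data
        lo, hi = min(values), max(values)
        self.vmin = lo if self.vmin is None else min(self.vmin, lo)
        self.vmax = hi if self.vmax is None else max(self.vmax, hi)
        if self.vmin == self.vmax:
            return [0.0 for _ in values]  # mirrors "Returns 0 if vmin == vmax"
        span = self.vmax - self.vmin
        return [(v - self.vmin) / span for v in values]


norm = AccumulatingNorm()
print(norm([0.0, 1.0]))  # [0.0, 1.0]
print(norm([2.0]))  # [1.0], the limits are now [0.0, 2.0]
print(norm([1.0]))  # [0.5]
```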

draw_sep_cbar(ax_cb: Optional[Axes] = None, cbar_label: Optional[str] = None, cbar_orientation: Optional[str] = 'vertical', fig_size: Optional[tuple] = (8, 6), cmap: Optional[Colormap] = None, norm: Optional[Normalize] = Normalize(), num_major_ticks_cb: Optional[int] = None)

Add a separate figure with a color bar.

Parameters:
  • ax_cb – axis to draw the color bar onto, if None a new figure is created

  • cbar_label – label for the color bar, if None no label is printed

  • cbar_orientation – orientation passed to ColorbarBase (‘vertical’ or ‘horizontal’)

  • fig_size – width and height of the figure in inches

  • cmap – colormap passed to ColorbarBase

  • norm – colormap normalizer passed to ColorbarBase

  • num_major_ticks_cb – number of major axis ticks for the color bar, set None for default behavior

Returns:

handle to the color bar figure

max_prime_factor(n: int) -> int

Get the largest prime number that is a factor of the given number.

Parameters:

n – given number \(n\)

Returns:

largest prime number \(p\) such that \(p \cdot x = n\) for some integer \(x\)
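A straightforward way to compute this is trial division: repeatedly divide out the smallest remaining factor; whatever is left at the end (or the last factor found) is the largest prime. A stdlib sketch that is not necessarily how the module implements it; it assumes \(n \ge 2\):

```python
def max_prime_factor_sketch(n: int) -> int:
    """Largest prime p dividing n, via trial division; assumes n >= 2."""
    if n < 2:
        raise ValueError("n must be at least 2")
    largest = 1
    p = 2
    while p * p <= n:
        while n % p == 0:
            largest = p
            n //= p
        p += 1
    # Whatever remains above 1 is a prime larger than all factors found so far
    return n if n > 1 else largest


print(max_prime_factor_sketch(12))  # 3
print(max_prime_factor_sketch(13))  # 13
print(max_prime_factor_sketch(100))  # 5
```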

most_square_product(n: int) -> Tuple[int, int]

Heuristic to get two integers whose product equals the input and which are as close to each other (i.e. as square-like) as possible.

Parameters:

n – given number \(n\)

Returns:

lower and higher integer
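One way to obtain such a pair is to search downward from \(\lfloor\sqrt{n}\rfloor\) for the first divisor; for a prime \(n\) this degenerates to (1, n). A stdlib sketch that may differ from the module's own heuristic:

```python
import math
from typing import Tuple


def most_square_product_sketch(n: int) -> Tuple[int, int]:
    """Return (a, b) with a * b == n, a <= b, and a as close to sqrt(n) as possible."""
    if n < 1:
        raise ValueError("n must be positive")
    for a in range(math.isqrt(n), 0, -1):
        if n % a == 0:
            return a, n // a
    raise AssertionError("unreachable, a = 1 always divides n")


print(most_square_product_sketch(12))  # (3, 4)
print(most_square_product_sketch(16))  # (4, 4)
print(most_square_product_sketch(7))  # (1, 7)
```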

num_rows_cols_from_length(n: int, transposed: bool = False) -> Tuple[int, int]

Use a heuristic to get the number of rows and columns for a plotting grid, given the total number of plots to draw.

Parameters:
  • n – total number of plots to draw

  • transposed – if True, swap the number of rows and the number of columns

Returns:

number of rows and columns
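A common heuristic for subplot grids, shown here as a sketch that is not necessarily the module's own: use about \(\sqrt{n}\) columns and as many rows as needed, so that only the last row can be partially filled; transposed swaps the two numbers.

```python
import math
from typing import Tuple


def num_rows_cols_sketch(n: int, transposed: bool = False) -> Tuple[int, int]:
    """Grid with ceil(sqrt(n)) columns and just enough rows to hold n plots."""
    if n < 1:
        raise ValueError("n must be positive")
    num_cols = math.ceil(math.sqrt(n))
    num_rows = math.ceil(n / num_cols)
    return (num_cols, num_rows) if transposed else (num_rows, num_cols)


print(num_rows_cols_sketch(7))  # (3, 3)
print(num_rows_cols_sketch(6))  # (2, 3)
print(num_rows_cols_sketch(6, transposed=True))  # (3, 2)
```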

Module contents

set_style(style_name: str = 'default')

Sets colors, fonts, font sizes, bounding boxes, and more for plots using pyplot.

Note

The font sizes of the predefined styles will be overwritten!

Parameters:

style_name – str containing the matplotlib style name, or ‘default’ for the Pyrado default style