Source code for nustattools.plotting._hinton

"""Plot matrices in a colourblind-friendly way."""

from __future__ import annotations

from typing import Any

import numpy as np
from matplotlib import collections, colormaps, patches, ticker
from matplotlib import pyplot as plt
from numpy.typing import NDArray


[docs] def hinton( matrix: NDArray[Any], *, vmax: float | None = None, shape: str = "circle", origin: str = "upper", cmap: str = "cividis", legend: bool = False, ax: None | Any = None, ) -> tuple[Any, Any]: """Draw Hinton diagram for visualizing a matrix with positive and negative values. Parameters ---------- matrix : numpy.ndarray The matrix to be visualized. vmax : float, optional The upper limit of the value scale. `-vmax` will be used as the lower limit. Defaults to being inferred from the data. shape : str, default="circle" Either "circle" or "square". The shape of the symbols representing the matrix elements. origin : str, default="upper" Either "upper" or "lower". Where to put the 1st element of the 1st axis. cmap : str, default="cividis" The Matplotlib colormap to take the colors from. Should be perceptually uniform sequantial. legend : bool, default=False Draw a "legend" to the side of the plot, showing the range of values. ax : matplotlib.axes.Axes, optional Axes object to plot onto Returns ------- col0, col1 : matplotlib.collections.PatchCollection The collections of patches for the negative and positive colors respectively Examples -------- .. plot:: :include-source: True Basic usage: >>> import numpy as np >>> from matplotlib import pyplot as plt >>> from nustattools import plotting as nuplt >>> rng = np.random.default_rng() >>> M = rng.uniform(size=(10,10)) - 0.5 >>> nuplt.hinton(M) .. plot:: :include-source: True Plot with a legend: >>> import numpy as np >>> from matplotlib import pyplot as plt >>> from nustattools import plotting as nuplt >>> rng = np.random.default_rng() >>> M = rng.uniform(size=(10,10)) - 0.5 >>> nuplt.hinton(M, legend=True) >>> plt.tight_layout(pad=2) .. plot:: :include-source: True Variants: >>> import numpy as np >>> from matplotlib import pyplot as plt >>> from nustattools import plotting as nuplt >>> rng = np.random.default_rng() >>> M = rng.uniform(size=(10,10)) - 0.5 >>> nuplt.hinton(M, legend=True, shape="square", cmap="gray", origin="lower") >>> plt.tight_layout(pad=2) Notes ----- Based on https://matplotlib.org/stable/gallery/specialty_plots/hinton_demo.html """ ax = ax if ax is not None else plt.gca() if vmax is None: vmax = np.abs(matrix).max() colors = colormaps[cmap]((0, 0.5, 1)) # Create the shapes patches0 = [] patches1 = [] def add_patch(x: float, y: float, w: float) -> None: size = np.sqrt(abs(w) / vmax) if shape == "circle": patch: patches.Patch = patches.Circle( (x, y), size / 2, ) elif shape == "square": patch = patches.Rectangle( (x - size / 2, y - size / 2), size, size, ) else: e = f"Unknown shape: {shape}" raise ValueError(e) if w > 0: patches1.append(patch) else: patches0.append(patch) for (y, x), w in np.ndenumerate(matrix): add_patch(x, y, w) # Add "legend" if legend: lw = int(np.ceil(np.max(matrix.shape) / 20)) lx = matrix.shape[1] + lw lh = int(np.ceil(matrix.shape[0] * 0.45) * 2 - 1) ly = ((matrix.shape[0] - 1) / 2) - ((lh - 1) / 2) for i, w in enumerate(np.linspace(-vmax, vmax, lh)): y = ly + (lh - 1) * (w / vmax + 1) / 2 ww = w if origin == "lower" else -w for x in range(lx, lx + lw): add_patch(x, y, ww) if i in (0, (lh - 1) / 2, lh - 1): ax.plot( [x + 0.6, x + 0.6 + 0.2 * lw], [y, y], clip_on=False, color="k", linewidth=1, ) if abs(ww) < 1e-12: ww = 0 ax.text(x + 0.6 + 0.2 * lw, y, f" {ww:.3g}", verticalalignment="center") ax.add_patch( patches.Rectangle( (lx - 0.6, ly - 0.6), lw + 0.2, lh + 0.2, facecolor=colors[1], edgecolor="k", clip_on=False, ) ) # Create collections and set colors, plots much faster than adding every single patch to the axes col0 = collections.PatchCollection(patches0) col0.set_facecolor(colors[0]) col0.set_clip_on(False) col1 = collections.PatchCollection(patches1) col1.set_facecolor(colors[2]) col1.set_clip_on(False) # Add collections to axes ax.add_collection(col0) ax.add_collection(col1) ax.patch.set_facecolor(colors[1]) ax.set_aspect("equal", "box") edge = 0.1 ax.set_xlim(-0.5 - edge, matrix.shape[1] - 0.5 + edge) ax.set_ylim(-0.5 - edge, matrix.shape[0] - 0.5 + edge) m: int = np.max(matrix.shape) ax.xaxis.set_major_locator( ticker.MaxNLocator( integer=True, min_n_ticks=1, nbins=1 + int(8 * matrix.shape[1] / m) ) ) ax.yaxis.set_major_locator( ticker.MaxNLocator( integer=True, min_n_ticks=1, nbins=1 + int(8 * matrix.shape[0] / m) ) ) if origin == "upper": ax.invert_yaxis() ax.xaxis.tick_top() elif origin == "lower": pass else: e = f"Unknown origin: {origin}" raise ValueError(e) return (col0, col1)
__all__ = ["hinton"]