Source code for eda_report.plotting

from io import BytesIO
from multiprocessing import get_context
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union

import matplotlib as mpl
import numpy as np
from matplotlib.axes import Axes
from matplotlib.colors import to_rgb
from matplotlib.figure import Figure
from scipy.stats import gaussian_kde, probplot
from tqdm import tqdm

from eda_report._validate import _validate_dataset, _validate_univariate_input
from eda_report.bivariate import Dataset

# Base Matplotlib rcParams shared by every graph in this module
GENERAL_RC_PARAMS = {
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.titlesize": 12,
    "axes.titleweight": 500,
    "figure.autolayout": True,
    "figure.figsize": (5.6, 3.5),
    "font.family": "serif",
    "savefig.dpi": 120,
}

# Box-plots: the base settings plus box-plot specific options
BOXPLOT_RC_PARAMS = GENERAL_RC_PARAMS.copy()
BOXPLOT_RC_PARAMS.update(
    {
        "boxplot.medianprops.color": "black",
        "boxplot.patchartist": True,
        "boxplot.vertical": False,
    }
)

# Correlation-plots: the base settings with a larger figure size
CORRPLOT_RC_PARAMS = GENERAL_RC_PARAMS.copy()
CORRPLOT_RC_PARAMS["figure.figsize"] = (7, 6.3)

# Regression-plots: the base settings with a near-square figure size
REGPLOT_RC_PARAMS = GENERAL_RC_PARAMS.copy()
REGPLOT_RC_PARAMS["figure.figsize"] = (5.2, 5)


@mpl.rc_context(GENERAL_RC_PARAMS)
def _savefig(figure: Figure) -> BytesIO:
    """Render a :class:`~matplotlib.figure.Figure` as an in-memory PNG image.

    Keeping the image in memory, rather than on disk, gives quick access to
    the graph when the report is compiled. The figure is saved while
    ``GENERAL_RC_PARAMS`` is active.

    Args:
        figure (matplotlib.figure.Figure): Graph content.

    Returns:
        io.BytesIO: File-like object holding the PNG bytes. The stream
        position is not rewound here.
    """
    png_buffer = BytesIO()
    figure.savefig(png_buffer, format="png")
    return png_buffer


def _get_or_validate_axes(ax: Axes = None) -> Axes:
    """Return a usable Axes: the one supplied, or a new one if none is given.

    Args:
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None,
            in which case an Axes is created on a new
            :class:`~matplotlib.figure.Figure`.

    Raises:
        TypeError: If `ax` is neither None nor an Axes instance.

    Returns:
        Axes: Axes instance.
    """
    if isinstance(ax, Axes):
        return ax

    if ax is not None:
        raise TypeError(f"Invalid input for 'ax': {type(ax)}")

    # Nothing supplied: create an Axes on a new Figure.
    return Figure().subplots()


def _get_color_shades_of(color: str, num: int = None) -> Sequence:
    """Build `num` evenly spaced shades running from `color` to dark gray.

    Args:
        color (str): The desired color (the first shade).
        num (int): Desired number of color shades.

    Returns:
        Sequence: Array of RGB colors, one row per shade.
    """
    # Every gradient ends at the same dark gray.
    darkest_shade = (0.25, 0.25, 0.25)
    return np.linspace(to_rgb(color), darkest_shade, num=num)


@mpl.rc_context(BOXPLOT_RC_PARAMS)
def box_plot(
    data: Iterable,
    *,
    label: str,
    hue: Iterable = None,
    color: Union[str, Sequence] = None,
    ax: Axes = None,
) -> Axes:
    """Draw a box-plot of numeric values, optionally split into groups.

    Missing values in ``data`` are dropped before plotting. When ``hue`` is
    given, one box is drawn per group and the group colors are applied to
    the boxes in reverse order.

    Args:
        data (Iterable): Values to plot.
        label (str): A name for the ``data``, shown in the title.
        hue (Iterable, optional): Values for grouping the ``data``. Defaults
            to None.
        color (Union[str, Sequence]): A valid matplotlib color specifier.
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None.

    Returns:
        matplotlib.axes.Axes: Matplotlib axes with the box-plot.
    """
    validated = _validate_univariate_input(data)
    values = validated.dropna()
    ax = _get_or_validate_axes(ax)
    title = f"Box-plot of {label}"

    if hue is None:
        # Single box: fill it with `color` and hide the lone tick label.
        ax.boxplot(
            values,
            tick_labels=[label],
            sym=".",
            boxprops=dict(facecolor=color, alpha=0.75),
        )
        ax.set_yticklabels("")
        ax.set_title(title)
        return ax

    # Keep only the hue entries whose matching data value is present.
    hue = _validate_univariate_input(hue)[validated.notna()]
    groups = {
        group_key: group_values
        for group_key, group_values in values.groupby(hue)
    }
    bxplot = ax.boxplot(groups.values(), tick_labels=groups.keys(), sym=".")

    num_groups = hue.nunique()
    colors = (
        [f"C{idx}" for idx in range(num_groups)]
        if color is None
        else _get_color_shades_of(color, num_groups)
    )
    for box, shade in zip(bxplot["boxes"], reversed(colors)):
        box.set_facecolor(shade)
        box.set_alpha(0.75)

    if hue.name is not None:
        ax.set_ylabel(f"{hue.name}".title())

    ax.set_title(title)
    return ax
@mpl.rc_context(GENERAL_RC_PARAMS)
def kde_plot(
    data: Iterable,
    *,
    label: str,
    hue: Iterable = None,
    color: Union[str, Sequence] = None,
    ax: Axes = None,
) -> Axes:
    """Get a kde-plot from numeric values.

    Missing values in ``data`` are dropped before plotting. If fewer than two
    values remain, or they have (close to) zero spread, no density is drawn
    and a notice is written on the axes instead. When ``hue`` is given, one
    density curve is drawn per group; groups that are themselves too small
    or constant are skipped, since a kernel density estimate cannot be
    fitted to them.

    Args:
        data (Iterable): Values to plot.
        label (str): A name for the ``data``, shown in the title.
        hue (Iterable, optional): Values for grouping the ``data``. Defaults
            to None.
        color (Union[str, Sequence]): A valid matplotlib color specifier.
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None.

    Returns:
        matplotlib.axes.Axes: Matplotlib axes with the kde-plot.
    """
    original_data = _validate_univariate_input(data)
    data = original_data.dropna()
    ax = _get_or_validate_axes(ax)

    if len(data) < 2 or np.isclose(data.std(), 0):
        msg = "[Could not plot kernel density estimate.\n Data is singular.]"
        ax.text(x=0.08, y=0.45, s=msg, color="#f72", size=14, weight=600)
        return ax

    # All curves are evaluated on the same grid spanning the pooled data.
    eval_points = np.linspace(data.min(), data.max(), num=len(data))
    if hue is None:
        kernel = gaussian_kde(data)
        density = kernel(eval_points)
        ax.plot(eval_points, density, label=label, color=color)
        ax.fill_between(eval_points, density, alpha=0.3, color=color)
    else:
        # Keep only the hue entries whose matching data value is present.
        hue = _validate_univariate_input(hue)[original_data.notna()]
        if color is None:
            colors = [f"C{idx}" for idx in range(hue.nunique())]
        else:
            colors = _get_color_shades_of(color, hue.nunique())

        for group_color, (key, series) in zip(colors, data.groupby(hue)):
            # The pooled data passed the singularity check above, but a
            # single group can still have < 2 points or no spread, which
            # makes gaussian_kde raise. Skip such groups; zipping before the
            # check keeps the remaining groups on their usual colors.
            if len(series) < 2 or np.isclose(series.std(), 0):
                continue
            kernel = gaussian_kde(series)
            density = kernel(eval_points)
            ax.plot(
                eval_points, density, label=key, alpha=0.75, color=group_color
            )
            ax.fill_between(
                eval_points, density, alpha=0.25, color=group_color
            )

        if hue.name is not None:
            ax.legend(title=f"{hue.name}".title())

    ax.set_xlabel(label)
    ax.set_ylim(0)
    ax.set_title(f"Density plot of {label}")
    return ax
@mpl.rc_context(REGPLOT_RC_PARAMS)
def prob_plot(
    data: Iterable,
    *,
    label: str,
    marker_color: Union[str, Sequence] = "C0",
    line_color: Union[str, Sequence] = "#222",
    ax: Axes = None,
) -> Axes:
    """Draw a probability-plot of numeric values.

    Missing values in ``data`` are dropped, then
    :func:`scipy.stats.probplot` draws onto the axes.

    Args:
        data (Iterable): Values to plot.
        label (str): A name for the ``data``, shown in the title.
        marker_color (Union[str, Sequence]): Color for the plotted points.
            Defaults to "C0".
        line_color (Union[str, Sequence]): Color for the line of best fit.
            Defaults to "#222".
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None.

    Returns:
        matplotlib.axes.Axes: Matplotlib axes with the probability-plot.
    """
    values = _validate_univariate_input(data).dropna()
    ax = _get_or_validate_axes(ax)
    probplot(values, fit=True, plot=ax)

    # The first two lines on the axes are taken to be the plotted points
    # and the fitted line, respectively.
    ax.lines[0].set_color(marker_color)
    ax.lines[1].set_color(line_color)

    ax.set_xlabel("Theoretical Quantiles (Normal)")
    ax.set_title(f"Probability plot of {label}")
    return ax
@mpl.rc_context(GENERAL_RC_PARAMS)
def bar_plot(
    data: Iterable,
    *,
    label: str,
    color: Union[str, Sequence] = None,
    ax: Axes = None,
) -> Axes:
    """Draw a bar-plot of how often each value occurs.

    Missing values are dropped, and only the 10 most frequent values are
    shown. Each bar is annotated with its count.

    Args:
        data (Iterable): Values to plot.
        label (str): A name for the ``data``, shown in the title.
        color (Union[str, Sequence]): A valid matplotlib color specifier.
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None.

    Returns:
        matplotlib.axes.Axes: Matplotlib axes with the bar-plot.
    """
    values = _validate_univariate_input(data).dropna()
    ax = _get_or_validate_axes(ax)

    # Cap the chart at the 10 most common values.
    most_common = values.value_counts().nlargest(10)
    bars = ax.bar(
        most_common.index.map(str), most_common, alpha=0.8, color=color
    )
    count_labels = [f"{count:,.0f}" for count in most_common]
    ax.bar_label(bars, labels=count_labels, padding=2)

    title = f"Bar-plot of {label}"
    num_unique = values.nunique()
    if num_unique > 10:
        # Tell the reader that some values were left out.
        title += f" (Top 10 of {num_unique})"
    ax.set_title(title)

    ax.set_ylabel("Count")
    # Vertical tick labels improve visibility for long labels.
    ax.tick_params(axis="x", rotation=90)
    return ax
def _plot_variable(variable_data_hue_and_color: Tuple) -> Tuple:
    """Plot one variable; written as a worker for a multiprocessing Pool.

    Numeric variables get a box-plot, a kde-plot and a probability-plot.
    Every other variable type gets a bar-plot. Each graph is converted to
    PNG bytes before being returned.

    Args:
        variable_data_hue_and_color (Tuple): A variable, plot data, hue data
            and the desired plot color.

    Returns:
        Tuple: `variable`s name, and graphs in a dict.
    """
    variable, data, hue, color = variable_data_hue_and_color
    var_name = variable.name

    if variable.var_type != "numeric":
        # {"boolean", "categorical", "datetime", "numeric (<=10 levels)"}
        plots = {"bar_plot": bar_plot(data, label=var_name, color=color)}
    else:
        plots = {
            "box_plot": box_plot(
                data=data, hue=hue, label=var_name, color=color
            ),
            "kde_plot": kde_plot(
                data=data, hue=hue, label=var_name, color=color
            ),
            "prob_plot": prob_plot(
                data, label=var_name, marker_color=color
            ),
        }

    graph_images = {}
    for plot_name, axes in plots.items():
        graph_images[plot_name] = _savefig(axes.figure)
    return var_name, graph_images
@mpl.rc_context(CORRPLOT_RC_PARAMS)
def plot_correlation(
    variables: Iterable,
    max_pairs: int = 20,
    color_pos: Union[str, Sequence] = "orangered",
    color_neg: Union[str, Sequence] = "steelblue",
    ax: Axes = None,
) -> Optional[Axes]:
    """Create a bar chart showing the top ``max_pairs`` most correlated
    variables. Bars are annotated with variable pairs and their respective
    Pearson correlation coefficients.

    Args:
        variables (Iterable): 2-dimensional numeric data.
        max_pairs (int): The maximum number of numeric pairs to include in
            the plot. Defaults to 20.
        color_pos (Union[str, Sequence]): Color for positive correlation
            bars. Defaults to "orangered".
        color_neg (Union[str, Sequence]): Color for negative correlation
            bars. Defaults to "steelblue".
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None.

    Returns:
        Optional[matplotlib.axes.Axes]: A bar-plot of correlation data, or
        None if the dataset has no correlation values.
    """
    if not isinstance(variables, Dataset):
        variables = Dataset(variables)

    if variables._correlation_values is None:
        return None

    # Show at most `max_pairs` numeric pairs.
    pairs_to_show = variables._correlation_values[:max_pairs]
    # Reverse items so largest values appear at the top.
    corr_data = dict(reversed(pairs_to_show))
    labels = [" vs ".join(pair) for pair in corr_data.keys()]

    ax = _get_or_validate_axes(ax)
    bars = ax.barh(
        labels, corr_data.values(), edgecolor="#222", linewidth=0.5
    )
    ax.set_xlim(-1.1, 1.1)
    ax.spines["left"].set_position("zero")  # Place y-axis spine at x=0
    ax.yaxis.set_visible(False)  # Hide y-axis labels

    # Iterate the bars just drawn rather than `ax.patches`, so that patches
    # already present on a caller-supplied axes cannot be mis-paired with
    # the labels.
    for bar, label in zip(bars, labels):
        coefficient = bar.get_width()
        # Stronger correlations get more opaque bars.
        bar.set_alpha(min(1, abs(coefficient) + 0.1))
        y_center = bar.get_y() + bar.get_height() / 2

        # Both branches format with fixed-point ".2f" so that positive and
        # negative coefficients are shown with the same precision.
        if coefficient < 0:
            bar.set_facecolor(color_neg)
            ax.text(
                bar.get_x(),
                y_center,
                f"{coefficient:,.2f} ({label})  ",
                size=8,
                ha="right",
                va="center",
            )
        else:
            bar.set_facecolor(color_pos)
            ax.text(
                bar.get_x(),
                y_center,
                f"  {coefficient:,.2f} ({label})",
                size=8,
                ha="left",
                va="center",
            )

    ax.set_title(f"Pearson Correlation (Top {len(corr_data)})")
    return ax
@mpl.rc_context(REGPLOT_RC_PARAMS)
def regression_plot(
    x: Iterable,
    y: Iterable,
    labels: Tuple[str, str],
    marker_color: Union[str, Sequence] = "C0",
    line_color: Union[str, Sequence] = "#444",
    ax: Axes = None,
) -> Axes:
    """Draw a scatter-plot of two numeric variables with a fitted line.

    Rows with missing values are dropped. If more than 50,000 rows remain, a
    sample of 50,000 rows is plotted instead. The title reports the slope and
    intercept of the fitted line, and the correlation between the two
    variables.

    Args:
        x (Iterable): Numeric values.
        y (Iterable): Numeric values.
        labels (Tuple[str, str]): Names for `x` and `y` respectively, shown
            in axis labels.
        marker_color (Union[str, Sequence]): Color for the plotted points.
            Defaults to "C0".
        line_color (Union[str, Sequence]): Color for the line of best fit.
            Defaults to "#444".
        ax (matplotlib.axes.Axes, optional): Axes instance. Defaults to None.

    Returns:
        matplotlib.axes.Axes: Matplotlib axes with the regression-plot.
    """
    x_name, y_name = labels
    frame = _validate_dataset({x_name: x, y_name: y}).dropna()

    max_points = 50000
    if len(frame) > max_points:
        frame = frame.sample(max_points)

    ax = _get_or_validate_axes(ax)
    x_values = frame[x_name]
    y_values = frame[y_name]

    # Degree-1 polynomial fit, i.e. a straight line.
    slope, intercept = np.polyfit(x_values, y_values, deg=1)
    ax.scatter(
        x_values,
        y_values,
        s=40,
        alpha=0.7,
        color=marker_color,
        edgecolors="#444",
    )

    line_x = np.linspace(x_values.min(), x_values.max(), num=20)
    ax.plot(line_x, slope * line_x + intercept, color=line_color, lw=2)

    title_lines = [
        f"Slope: {slope:,.4f}",
        f"Intercept: {intercept:,.4f}",
        f"Correlation: {x_values.corr(y_values):.4f}",
    ]
    ax.set_title("\n".join(title_lines), size=11)
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    return ax
def _plot_regression(data_and_color: Tuple) -> Tuple:
    """Draw the regression-plot for one pair of columns (Pool worker).

    Args:
        data_and_color (Tuple): Dataframe, and desired marker-color. The
            dataframe's two columns are used as `x` and `y`, in that order.

    Returns:
        Tuple: Names for the variable pair, and axes with the regression
        plot.
    """
    frame, marker_color = data_and_color
    x_name, y_name = frame.columns
    axes = regression_plot(
        x=frame[x_name],
        y=frame[y_name],
        labels=(x_name, y_name),
        marker_color=marker_color,
    )
    return (x_name, y_name), axes


def _plot_dataset(variables: Dataset, color: str = None) -> Optional[Dict]:
    """Plot the correlation graph, and regression-plots for the leading
    variable pairs, using a multiprocessing Pool for the regression-plots.

    Args:
        variables (Dataset): Bi-variate analysis results.
        color (str, optional): The color to apply to the graphs.

    Returns:
        Optional[Dict]: A dictionary with a correlation plot and regression
        plots, or None if there are no correlation values.
    """
    if variables._correlation_values is None:
        return None

    # Take the top 20 pairs by magnitude of correlation.
    # 20 var_pairs ≈ 10+ pages in report document
    # 20 numeric columns == 190 var_pairs ≈ 95+ pages.
    top_pairs = [pair for pair, _ in variables._correlation_values[:20]]
    jobs = [(variables.data.loc[:, pair], color) for pair in top_pairs]

    with get_context("spawn").Pool() as pool:
        progress = tqdm(
            # Plot in parallel processes
            pool.imap(_plot_regression, jobs),
            # Progress-bar options
            total=len(top_pairs),
            bar_format=(
                "{desc} {percentage:3.0f}%|{bar:35}| "
                "{n_fmt}/{total_fmt} pairs."
            ),
            desc="Bivariate analysis:",
            dynamic_ncols=True,
        )
        # Consume the results while the pool is still open.
        regression_axes = dict(progress)

    correlation_image = _savefig(plot_correlation(variables).figure)
    regression_images = {
        var_pair: _savefig(axes.figure)
        for var_pair, axes in regression_axes.items()
    }
    return {
        "correlation_plot": correlation_image,
        "regression_plots": regression_images,
    }