From 0d42713b3be0cd4df4c2aeff1a993960e6c4740c Mon Sep 17 00:00:00 2001 From: Vincent Donnefort Date: Wed, 27 Jan 2021 17:28:14 +0000 Subject: [PATCH] lisa.stats: Add plotly support for plot_stats() Plotly is a data visualization framework which offers a more user-friendly interface, compared to Matplotlib. This patch allows using this framework with the Stats plot_stats() function. The latter still defaults to Matplotlib. Plotly can be enabled by passing backend='plotly'. --- lisa/stats.py | 204 ++++++++++++++++++++++++++++++++++++----------- setup.py | 2 + shell/lisa_shell | 3 +- 3 files changed, 162 insertions(+), 47 deletions(-) diff --git a/lisa/stats.py b/lisa/stats.py index 013734fca..ef788559e 100644 --- a/lisa/stats.py +++ b/lisa/stats.py @@ -22,6 +22,9 @@ from itertools import combinations import scipy.stats import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from plotly.subplots import make_subplots as plotly_make_subplots import numpy as np from lisa.utils import Loggable, memoized, FrozenDict, deduplicate, fold @@ -806,9 +809,22 @@ class Stats(Loggable): df[val_col] = df[val_col].round(10) return df - def _plot(self, df, title, plot_func, facet_rows, facet_cols, collapse_cols, filename=None, interactive=None): + def _plot(self, + df, + title, + plot_func, + facet_rows, + facet_cols, + collapse_cols, + tick_group=None, + filename=None, + interactive=None, + backend='matplotlib'): unit_col = self._unit_col + if tick_group and backend == 'matplotlib': + raise ValueError(f'tick_group not supported with matplotlib') + group_on = list(facet_rows) + list(facet_cols) facet_rows_len = len(facet_rows) def split_row_col(group): @@ -843,48 +859,89 @@ class Stats(Loggable): collapsed_col = None collapse_group = {} - figure, axes = make_figure( - width=16, - height=16, - nrows=nrows, - ncols=ncols, - interactive=interactive, - ) - if nrows == 1 and ncols == 1: - axes = [[axes]] - elif nrows == 1: - axes = [axes] - elif ncols == 1: - 
axes = [[ax] for ax in axes] - - figure.set_tight_layout(dict( - h_pad=3.5, - )) - figure.suptitle(title, y=1.01, fontsize=30) - - for group, subdf in grouped: - group = fixup_tuple(group) - - row, col = split_row_col(group) - ax = axes[rows.index(row)][cols.index(col)] - - if subdf.empty: - figure.delaxes(ax) - else: + def _plot_plotly(): + figure = plotly_make_subplots( + rows=nrows, cols=ncols, + subplot_titles=[ + '-'.join(group) if isinstance(group, tuple) else group + for group, subdf in grouped + ] + ) + + showlegend=True + for group, subdf in grouped: + group = fixup_tuple(group) + row, col = split_row_col(group) + group_dict = dict(zip(group_on, group)) + x = rows.index(row) + 1 + y = cols.index(col) + 1 subdf = subdf.drop(columns=group_on) subdf = self._collapse_cols(subdf, collapse_group) group_dict = dict(zip(group_on, group)) - plot_func(subdf, ax, collapsed_col, group_dict) - - if filename: - # The suptitle is not taken into account by tight layout by default: - # https://stackoverflow.com/questions/48917631/matplotlib-how-to-return-figure-suptitle - suptitle = figure._suptitle - figure.savefig(filename, bbox_extra_artists=[suptitle], bbox_inches='tight') + plot_func(df=subdf, ax=(figure, x, y), + collapsed_col=collapsed_col, group=group_dict, + tick_group=tick_group, showlegend=showlegend) + showlegend=False + + if filename: + if not filename.endswith('.html'): + raise ValueError(f'backend=plotly only supports HTML export') + figure.write_html(filename) + + return figure + + def _plot_matplotlib(): + figure, axes = make_figure( + width=16, + height=16, + nrows=nrows, + ncols=ncols, + interactive=interactive, + ) + if nrows == 1 and ncols == 1: + axes = [[axes]] + elif nrows == 1: + axes = [axes] + elif ncols == 1: + axes = [[ax] for ax in axes] + + figure.set_tight_layout(dict( + h_pad=3.5, + )) + figure.suptitle(title, y=1.01, fontsize=30) + + for group, subdf in grouped: + group = fixup_tuple(group) + + row, col = split_row_col(group) + ax = 
axes[rows.index(row)][cols.index(col)] + + if subdf.empty: + figure.delaxes(ax) + else: + subdf = subdf.drop(columns=group_on) + subdf = self._collapse_cols(subdf, collapse_group) + group_dict = dict(zip(group_on, group)) + plot_func(df=subdf, ax=ax, collapsed_col=collapsed_col, + group=group_dict) + + if filename: + # The suptitle is not taken into account by tight layout by default: + # https://stackoverflow.com/questions/48917631/matplotlib-how-to-return-figure-suptitle + suptitle = figure._suptitle + figure.savefig(filename, bbox_extra_artists=[suptitle], bbox_inches='tight') + + return figure + + if backend == 'plotly': + return _plot_plotly() + else: + return _plot_matplotlib() - return figure - def plot_stats(self, filename=None, remove_ref=None, interactive=None, groups_as_row=True, kind=None, **kwargs): + def plot_stats(self, filename=None, remove_ref=None, interactive=None, + groups_as_row=True, kind=None, backend='matplotlib', + **kwargs): """ Returns a :class:`matplotlib.figure.Figure` containing the statistics for the class input :class:`pandas.DataFrame`. 
@@ -923,6 +980,9 @@ class Stats(Loggable): **kwargs ) + if kind not in ['horizontal_bar', 'vertical_bar']: + raise ValueError(f'Unsupported plot kind: {kind}') + mean_suffix = ' (confidence level: {:.1f}%)'.format( self._mean_ci_confidence * 100 ) @@ -937,7 +997,7 @@ class Stats(Loggable): f' compared against: {pretty_ref_group}' if self._compare else '' ) - def plot(df, ax, collapsed_col, group): + def plot(*, df, ax, collapsed_col, group, tick_group=None, showlegend=True): try: error = [ df[col] @@ -961,12 +1021,7 @@ class Stats(Loggable): y_col = self._val_col - if kind == 'horizontal_bar': - plot = df.plot.barh - elif kind == 'vertical_bar': - plot = df.plot.bar - else: - raise ValueError(f'Unsupported plot kind: {kind}') + plot = df.plot.barh if kind == 'horizontal_bar' else df.plot.bar plot( ax=ax, @@ -1044,6 +1099,49 @@ class Stats(Loggable): textcoords='offset points', ) + def plot_plotly(*, df, ax, collapsed_col, group, tick_group, showlegend): + df = df.copy() + fig, facet_row, facet_col = ax + df['text'] = df[self._val_col].round(2).astype(str) + ' ' + df[self._unit_col] + + error_y = None if df[self._ci_cols[1]].isna().any() else dict( + array=df[self._ci_cols[1]], + arrayminus=df[self._ci_cols[0]] + ) + + bar_colors = iter(px.colors.qualitative.Plotly); + + if tick_group: + grouped_tick = df.groupby([tick_group], observed=True) + else: + grouped_tick = [('-'.join(group.values()), df)] + + for subtick, subdf in grouped_tick: + x = subdf[collapsed_col] if collapsed_col else None + y = subdf[self._val_col] + + if kind == 'horizontal_bar': + x, y = y, x + + fig.add_bar(x=x, y=y, + text=subdf['text'], textposition='auto', + error_y=error_y, + name=subtick, + showlegend=showlegend, + legendgroup=subtick, + marker_color=next(bar_colors), + orientation='h' if kind == 'horizontal_bar' else 'v', + row=facet_row, col=facet_col) + + update_units = fig.update_xaxes if kind == 'horizontal_bar' else fig.update_yaxes + 
update_units(title_text=df[self._unit_col].iloc[0], + row=facet_row, col=facet_col) + + if kind == 'horizontal_bar': + fig.update_yaxes(tickangle=-90, row=facet_row, col=facet_col) + + fig.update_layout(barmode='group', title=title) + # Subplot matrix: # * one line per sub-group (e.g. metric) # * one column per stat @@ -1060,15 +1158,29 @@ class Stats(Loggable): if groups_as_row: facet_rows, collapse_cols = collapse_cols, facet_rows + # plotly backend supports another grouping layer, before having the + # need for collapsing columns. + tick_group = None + if backend == 'plotly' and collapse_cols: + collapse_cols_nunique = { + col: df[col].nunique() + for col in collapse_cols + } + tick_group = max(collapse_cols_nunique, + key=collapse_cols_nunique.get) + collapse_cols.remove(tick_group) + return self._plot( df, title=title, - plot_func=plot, + plot_func=plot if backend == 'matplotlib' else plot_plotly, facet_rows=facet_rows, facet_cols=facet_cols, + tick_group=tick_group, collapse_cols=collapse_cols, filename=filename, interactive=interactive, + backend=backend, ) @staticmethod diff --git a/setup.py b/setup.py index 858e9424f..0d6dd085c 100755 --- a/setup.py +++ b/setup.py @@ -110,6 +110,8 @@ setup( # Depdendencies that are shipped as part of the LISA repo as # subtree/submodule "devlib", + + "plotly" ], extras_require=extras_require, diff --git a/shell/lisa_shell b/shell/lisa_shell index eae3d1f32..010efe58b 100755 --- a/shell/lisa_shell +++ b/shell/lisa_shell @@ -193,7 +193,8 @@ function _lisa-install-nbextensions { echo "Installing jupyter lab extensions..." jupyter labextension install --minimize=False \ @jupyter-widgets/jupyterlab-manager \ - jupyter-matplotlib + jupyter-matplotlib \ + jupyterlab-plotly return fi -- GitLab