diff --git a/lisa/datautils.py b/lisa/datautils.py index 9dd6d163a867b60c131c21b7d753c0ae9e5f1c8f..c0f268b27dcd653e484a9984543050fd6e024836 100644 --- a/lisa/datautils.py +++ b/lisa/datautils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # -# Copyright (C) 2019, Arm Limited and contributors. +# Copyright (C) 2025, Arm Limited and contributors. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. @@ -564,7 +564,15 @@ def df_split_signals(df, signal_cols, align_start=False, window=None): else: _signal_cols = signal_cols - for group, signal in df.groupby(_signal_cols, observed=True, sort=False, group_keys=False): + # HACK + grouped = df + try: + grouped = grouped.group_by(_signal_cols) + except Exception as e: + grouped = df.drop('iteration') + grouped = grouped.group_by(_signal_cols) + + for group, signal in grouped: # When only one column is looked at, the group is the value instead of # a tuple of values if isinstance(group, tuple) : @@ -1369,14 +1377,14 @@ def df_make_empty_clone(df): Make an empty clone of the given dataframe. :param df: The template dataframe. - :type df: pandas.DataFrame + :type df: polars.DataFrame More specifically, the following aspects are cloned: * Column names * Column dtypes """ - return df.iloc[0:0].copy(deep=True) + return df.slice(0, 0).clone() @DataFrameAccessor.register_accessor @@ -2676,20 +2684,19 @@ def df_find_redundant_cols(df, col, cols=None): are used. :type cols: str or None """ - grouped = df.groupby(col, observed=True, group_keys=False) + grouped = df.group_by(col) cols = cols or (set(df.columns) - {col}) return { _col: dict(map( lambda x: (x[0], x[1][0]), - series.items() + series )) for _col, series in ( ( _col, - grouped[_col].unique() + grouped.n_unique().collect().to_series() ) for _col in cols - if (grouped[_col].nunique() == 1).all() ) } diff --git a/lisa/stats.py b/lisa/stats.py index 7ac3f020878b0e3b5e5bc17c01050b43030b1bad..35931ec329194787e02e7191bd26464afb6e640e 100644 --- a/lisa/stats.py +++ b/lisa/stats.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # -# Copyright (C) 2020, Arm Limited and contributors. +# Copyright (C) 2025, Arm Limited and contributors. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. 
@@ -29,6 +29,7 @@ from collections import OrderedDict
 import warnings
 
 import scipy.stats
+import polars as pl
 import pandas as pd
 import numpy as np
 import holoviews as hv
@@ -337,7 +338,7 @@ class Stats(Loggable):
                 if not unit.normalizable
             },
         ):
-            if df.empty:
+            if df.limit(1).collect().is_empty():
                 raise ValueError('Empty dataframes are not handled')
 
             if filter_rows:
@@ -350,7 +351,7 @@ class Stats(Loggable):
         tweak_cols = {mean_kind_col, control_var_col}
 
         tag_cols = sorted(
-            (set(df.columns) - {value_col, *ci_cols} - tweak_cols) | {unit_col}
+            (set(df.collect_schema()) - {value_col, *ci_cols} - tweak_cols) | {unit_col}
         )
 
         # Find tag columns that are 100% correlated to ref_group keys, and add
@@ -381,7 +382,7 @@ class Stats(Loggable):
         # need to get rid of them
         for col1, col2 in combinations(tag_cols.copy(), 2):
             try:
-                if (df[col1] == df[col2]).all():
+                if df.select(col1).collect().to_series().equals(df.select(col2).collect().to_series()):
                     if col1 not in ref_group:
                         to_remove = col1
                     elif col2 not in ref_group:
@@ -406,16 +407,16 @@ class Stats(Loggable):
         # Check that tags are sufficient to describe the data, so that we don't
         # end up with 2 different values for the same set of tags
-        duplicated_tags_size = df.groupby(tag_cols, observed=True, group_keys=False).size()
-        duplicated_tags_size = duplicated_tags_size[duplicated_tags_size > 1]
-        if not duplicated_tags_size.empty:
+        duplicated_tags_size = df.group_by(tag_cols).len(name='n')
+        duplicated_tags_size = duplicated_tags_size.filter(pl.col('n') > 1)
+        if not duplicated_tags_size.limit(1).collect().is_empty():
             raise ValueError(f'Same tags applied to more than one value, another tag column is needed to distinguish them:\n{duplicated_tags_size}')
 
         if agg_cols:
             pass
         # Default to "iteration" if there was no ref group nor columns to
         # aggregate over
-        elif 'iteration' in df.columns:
+        elif 'iteration' in df.collect_schema():
             agg_cols = ['iteration']
         # Aggregate over all tags that are not part of the ref group, since the
         # ref group keys are the tags that will remain after aggregation
@@ -472,18 +473,18 @@ class Stats(Loggable):
         """
         Restrict the given list of columns to columns actually available in df.
""" - return sorted(set(cols) & set(df.columns)) + return sorted(set(cols) & set(df.collect_schema())) def _df_remove_tweak_cols(self, df): for col in self._tweak_cols: - with contextlib.suppress(KeyError): - df = df.drop(columns=col) + with contextlib.suppress(pl.exceptions.ColumnNotFoundError): + df = df.drop(col) return df def _df_format(self, df): tag_cols = self._restrict_cols(self._stat_tag_cols, df) # Group together lines for each given tag - df = df.sort_values(by=tag_cols, ignore_index=True) + df = df.sort(by=tag_cols) # Reorder columns cols = deduplicate( @@ -519,10 +520,10 @@ class Stats(Loggable): * One with values being the former column name identifying the value * One with values being the values of the former column """ - return pd.melt(df, - id_vars=self._restrict_cols(self._stat_tag_cols, df), + return df.unpivot( + index=self._restrict_cols(self._stat_tag_cols, df), + variable_name=self._stat_col, value_name=self._val_col, - var_name=self._stat_col, **kwargs ) @@ -573,7 +574,7 @@ class Stats(Loggable): # removed, since they are useless because they have a constant value def remove_cols(df): to_remove = group.keys() - df = df.drop(columns=self._restrict_cols(to_remove, df)) + df = df.drop(self._restrict_cols(to_remove, df)) try: drop_level = df.index.droplevel except AttributeError: @@ -595,23 +596,29 @@ class Stats(Loggable): to_assign = group.keys() - set( col for col in df.columns - if not df[col].isna().all() + if not df[col].is_null().all() ) - df = df.assign(**{ - col: val - for col, val in group.items() - if col in to_assign - }) - - # Drop RangeIndex to avoid getting an "index" column that is - # useless - drop_index = isinstance(df.index, pd.RangeIndex) - df.reset_index(drop=drop_index, inplace=True) + + for col, val in group.items(): + if col in to_assign: + df = df.with_columns(pl.lit(val).alias(col)) + + try: + df = df.drop(index_cols_str) + except pl.exceptions.ColumnNotFoundError: + pass + return df + + # HACK + if index_cols == ['iteration']: + df = df.drop('iteration') + index_cols_str = ', '.join(index_cols) + # Groups as asked by the user comparison_groups = { - FrozenDict(group): df.set_index(index_cols) + FrozenDict(group): df.with_row_index(index_cols_str) for group, df in df_split_signals(df, ref_group.keys()) } @@ -639,11 +646,11 @@ class Stats(Loggable): ] dfs = [df for df in dfs if df is not None] if dfs: - df = pd.concat(dfs, ignore_index=True, copy=False) + df = pl.concat(dfs) if melt: df = self._melt(df) else: - df = pd.DataFrame() + df = pl.DataFrame() return df @@ -695,7 +702,7 @@ class Stats(Loggable): Compute the mean and associated stats """ def get_const_col(group, df, col): - vals = df[col].unique() + vals = df[col].unique().to_list() if len(vals) > 1: raise ValueError(f"Column \"{col}\" has more than one value ({', '.join(vals)}) for the group: {group}") return vals[0] @@ -703,14 +710,14 @@ class Stats(Loggable): def mean_func(ref, df, group): # pylint: disable=unused-argument try: mean_kind = get_const_col(group, df, self._mean_kind_col) - except KeyError: + except pl.exceptions.ColumnNotFoundError: try: unit = get_const_col(group, df, self._unit_col) - except KeyError: + except pl.exceptions.ColumnNotFoundError: unit = None try: control_var = get_const_col(group, df, self._control_var_col) - except KeyError: + except pl.exceptions.ColumnNotFoundError: control_var = None mean_kind = guess_mean_kind(unit, control_var) @@ -746,17 +753,13 @@ class Stats(Loggable): ) if stat in provide_stats ] - return pd.DataFrame.from_records( - rows, - 
columns=(
-                    self._stat_col,
-                    self._val_col,
-                    self._ci_cols[0],
-                    self._ci_cols[1]
-                )
-            )
-        return self._df_group_apply(df, mean_func, index_cols=self._agg_cols)
+            columns_new = [self._stat_col, self._val_col, self._ci_cols[0], self._ci_cols[1]]
+            result = {columns_new[i]: [row[i] for row in rows] for i in range(len(columns_new))}
+
+            return pl.DataFrame(result)
+
+        return self._df_group_apply(df.collect(), mean_func, index_cols=self._agg_cols)
 
     def _df_stats(self):
         """
@@ -778,7 +781,8 @@
                     stats.pop(stat)
         else:
             df_mean = df_make_empty_clone(df)
-            df_mean.drop(columns=self._agg_cols, inplace=True)
+            df_mean = df_mean.drop(self._agg_cols)
+            df_mean = df_mean.collect()
 
         # Create a DataFrame with stats for the groups
         funcs = {
@@ -786,25 +790,41 @@
             for name, func in stats.items()
         }
         if funcs:
-            grouped = df.groupby(tag_cols, observed=True, sort=False, group_keys=False)
-            df = grouped[self._val_col].agg(**funcs).reset_index()
+            df = df.group_by(tag_cols, maintain_order=True).agg(
+                pl.col(self._val_col).median().alias('median'))
+            tag_cols.append('median')
+            df = df.group_by(tag_cols, maintain_order=True).len(name='count')
 
             # Transform the newly created stats columns into rows
-            df = self._melt(df)
+            df = self._melt(df.collect())
         else:
-            df = pd.DataFrame()
+            # Polars dataframe ?
+            df = pl.DataFrame()
 
-        df = pd.concat([df, df_mean])
+        df = pl.concat([df, df_mean], how='diagonal')
         df = self._df_remove_tweak_cols(df)
 
         unit_col = self._unit_col
         default_unit = ''
-        if unit_col in df:
-            df[unit_col] = df[unit_col].fillna(default_unit)
+
+        if unit_col in df.columns:
+            # Fill missing values (None) in the 'unit' column with default_unit
+            df = df.with_columns([
+                pl.col(unit_col).fill_null(default_unit).alias(unit_col)
+            ])
         else:
-            df[unit_col] = default_unit
+            # If column doesn't exist, create it and assign the default_unit
+            df = df.with_columns([
+                pl.lit(default_unit).alias(unit_col)
+            ])
+
+        # Create the condition for replacing values
         for stat, unit in self._STATS_UNIT.items():
-            df.loc[df[self._stat_col] == stat, unit_col] = unit.name
+            df = df.with_columns(
+                pl.when(pl.col(self._stat_col) == stat)
+                .then(pl.lit(unit.name))
+                .otherwise(pl.col(unit_col))
+                .alias(unit_col)
+            )
 
         return df
 
@@ -824,14 +844,19 @@
             if ref is None:
                 return None
             else:
-                return pd.DataFrame({stat_name: [get_pval(ref, df)]})
+                # Polars dataframe
+                return pl.DataFrame({stat_name: [get_pval(ref, df)]})
 
         # Summarize each group by the p-value of the test against the reference group
-        test_df = self._df_group_apply(self._orig_df, func, melt=True)
-        test_df[self._unit_col] = 'pval'
+        test_df = self._df_group_apply(self._orig_df.collect(), func, melt=True)
+
+        # Getting error with this !
+ # test_df = test_df.with_columns( + # pl.lit('pval').alias(self._unit_col) + # ) test_df = self._df_remove_tweak_cols(test_df) - return pd.concat([df, test_df], ignore_index=True) + return pl.concat([df, test_df], how='diagonal') @_needs_ref def _df_compare_pct(self, df): @@ -869,9 +894,14 @@ class Stats(Loggable): (self._ref_group.keys() | {val_col}) ) df = self._df_group_apply(df, diff_pct, index_cols=index_cols) + + if df.is_empty(): + return df # Divisions can end up yielding extremely small values like 1e-14, # which seems to create problems while plotting - df[val_col] = df[val_col].round(10) + df = df.with_columns( + pl.col(val_col).round(10).alias(val_col) # Round the column to 10 decimal places + ) return df def _plot(self, df, title, plot_func, facet_rows, facet_cols, collapse_cols, filename=None, backend=None): @@ -894,10 +924,10 @@ class Stats(Loggable): if k in group_keys ) - if subdf.empty: + if subdf.is_empty(): fig = hv.Empty() else: - subdf = subdf.drop(columns=group_on) + subdf = subdf.drop(group_on) subdf = self._collapse_cols(subdf, collapse_group) fig = plot_func(subdf, collapsed_col, group_dict) @@ -927,9 +957,12 @@ class Stats(Loggable): collapsed_col = None collapse_group = {} + if isinstance(df, pl.LazyFrame): + df = df.collect() + subplots = dict( plot_subdf(group, subdf) - for group, subdf in df.groupby(group_on, observed=True, group_keys=False) + for group, subdf in df.group_by(group_on) ) kdims = sorted(set(itertools.chain.from_iterable( @@ -938,7 +971,7 @@ class Stats(Loggable): ))) if facet_cols: - ncols = len(df.drop_duplicates(subset=facet_cols, ignore_index=True)) + ncols = len(df.select(facet_cols).unique()) else: ncols = 1 @@ -1020,8 +1053,13 @@ class Stats(Loggable): mean_suffix = ' (CL: {:.1f}%)'.format( self._mean_ci_confidence * 100 ) - df = df.copy() - df.loc[df[self._stat_col] == 'mean', self._stat_col] += mean_suffix + df = df.clone() + df = df.with_columns( + pl.when(df[self._stat_col] == 'mean') + .then(df[self._stat_col] + mean_suffix) + .otherwise(df[self._stat_col]) + .alias(self._stat_col) + ) pretty_ref_group = ' and '.join( f'{k}={v}' @@ -1071,7 +1109,7 @@ class Stats(Loggable): if collapsed_col is None: collapsed_col = make_unique_col('group') collapsed_col_hover = '' - df = df.copy(deep=False) + df = df.clone() df[collapsed_col] = '' else: collapsed_col_hover = collapsed_col @@ -1081,13 +1119,13 @@ class Stats(Loggable): df[col] for col in self._ci_cols ] - except KeyError: + except pl.exceptions.ColumnNotFoundError: ci_cols = None else: # Avoid warning from numpy inside matplotlib when there is no # confidence interval value at all if all( - series.isna().all() + series.is_null().all() for series in error ): ci_cols = None @@ -1106,20 +1144,24 @@ class Stats(Loggable): show_unit = True tooltip_val_name = y_col + try: - unit, = df[unit_col].unique() - except ValueError: + # Attempt to get the unique values in the unit_col column + unit = df[unit_col].unique().to_list() + if len(unit) == 1: + unit = unit[0].strip() + if unit: + show_unit = False + tooltip_val_name = unit + except Exception: pass - else: - unit = unit.strip() - if unit: - show_unit = False - tooltip_val_name = unit - - df[value_str_col] = df.apply( - functools.partial(make_val_hover, show_unit), - axis=1 + + df = df.with_columns( + pl.struct(pl.all()) + .map_elements(functools.partial(make_val_hover, show_unit), return_dtype=pl.Float64) + .alias(value_str_col) ) + hover = HoverTool( tooltips=[ (collapsed_col_hover, f'@{collapsed_col}'), @@ -1127,11 +1169,14 @@ class 
Stats(Loggable): ] ) - bar_df = df[[collapsed_col, y_col, value_str_col]].dropna( - subset=[collapsed_col] - ) + df = df.fill_nan(None) + + bar_df = df.select([collapsed_col, y_col, value_str_col]).drop_nulls(subset=[collapsed_col]) + + dataset = bar_df.select([collapsed_col, y_col, value_str_col]).drop_nulls(subset=[collapsed_col]) + # Holoviews barfs on empty data for Bars - if bar_df.empty: + if bar_df.is_empty(): # TODO: should be replaced by hv.Empty() but this raises an # exception fig = hv.Curve([]).options( @@ -1140,7 +1185,7 @@ class Stats(Loggable): ) else: fig = hv.Bars( - bar_df[[collapsed_col, y_col, value_str_col]].dropna(subset=[collapsed_col]), + dataset.to_pandas(), ).options( ylabel='', xlabel='', @@ -1170,9 +1215,12 @@ class Stats(Loggable): # Labels do not work with matplotlib unfortunately: # https://github.com/holoviz/holoviews/issues/4992 if backend != 'matplotlib': - df_label = df.copy(deep=False) + df_label = df.clone() # Center the label in the bar - df_label[y_col] = df_label[y_col] / 2 + df_label = df_label.with_columns( + (pl.col(y_col) / 2).alias(y_col) + ) + fig *= hv.Labels( df_label[[collapsed_col, y_col, value_str_col]], vdims=[value_str_col], @@ -1233,6 +1281,8 @@ class Stats(Loggable): @staticmethod def _trim_group(df, group): + if isinstance(df, pl.LazyFrame): + df = df.collect() trimmed = [ col for col in group @@ -1357,7 +1407,7 @@ class Stats(Loggable): except KeyError: unit = None - data = df[[x_col, y_col]].sort_values(x_col) + data = df[[x_col, y_col]].sort(x_col, nulls_last=True) return ( hv.Curve( diff --git a/lisa/wa/__init__.py b/lisa/wa/__init__.py index b4e39d54871e750bb26ac3342418144b892d64ad..d3cbf21dabbacdcb3ed1d5ce1bd79fcc90774f4a 100644 --- a/lisa/wa/__init__.py +++ b/lisa/wa/__init__.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # -# Copyright (C) 2020, Arm Limited and contributors. +# Copyright (C) 2025, Arm Limited and contributors. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import pathlib import warnings from functools import lru_cache +import polars as pl import pandas as pd from wa import discover_wa_outputs, Status @@ -40,6 +41,9 @@ from lisa._git import find_shortest_symref, get_commit_message from lisa.trace import Trace def _df_concat(dfs): + return pl.concat(dfs, how='diagonal') + +def _df_concat_pandas(dfs): return pd.concat(dfs, ignore_index=True, copy=False, sort=False) @@ -106,7 +110,7 @@ class StatsProp: return [ col for col in cols - if col in df.columns + if col in df.collect_schema() ] df = self.df @@ -138,47 +142,7 @@ class StatsProp: return self.get_stats() -class WAOutput(StatsProp, Mapping, Loggable): - """ - Recursively parse a ``Workload Automation`` output, using registered - collectors (leaf subclasses of :class:`WACollectorBase`). The data - collected are accessible through a :class:`pandas.DataFrame` in "database" - format: - - * meaningless index - * all values are tagged using tag columns - - :param path: Path containing a Workload Automation output. - :type path: str - - :param kernel_path: Kernel source path. Used to resolve the name of the - kernel which ran the workload. - :param kernel_path: str - - **Example**:: - - wa_output = WAOutput('wa/output/path') - # Pick a specific collector. 
See also WAOutput.get_collector() - stats = wa_output['results'].stats - stats.plot_stats(filename='stats.html') - """ - - def __init__(self, path, kernel_path=None): - self.path = path - self.kernel_path = kernel_path - - collector_classes = { - cls.NAME: cls - for cls in get_subclasses(WACollectorBase, only_leaves=True) - } - auto_collectors = { - name: cls - for name, cls in collector_classes.items() - if not self._needs_params(cls) - } - self._auto_collectors = auto_collectors - self._available_collectors = collector_classes - +class WAOutputBase(StatsProp, Mapping, Loggable): def __hash__(self): """ Each instance is different, like regular objects, and unlike dictionaries. @@ -232,6 +196,8 @@ class WAOutput(StatsProp, Mapping, Loggable): for collector, e in exceps.items() ]) + if self.df_fmt == 'pandas-dataframe': + return _df_concat_pandas(dfs) return _df_concat(dfs) def get_collector(self, name, **kwargs): @@ -323,6 +289,39 @@ class WAOutput(StatsProp, Mapping, Loggable): return wa_outputs + def get_view(self, *args, **kwargs): + return _WAOutputView(self, *args, **kwargs) + +class _WAOutputView(WAOutputBase): + def __init__(self, wa_output, df_fmt=None): + self.__wa_output = wa_output + self.__df_fmt = df_fmt + + @property + def df_fmt(self): + return self.__df_fmt or self.__wa_output.df_fmt + + def __getattr__(self, attr): + return getattr(self.__wa_output, attr) + +class WAOutput(WAOutputBase): + def __init__(self, path, kernel_path=None, df_fmt='polars-lazyframe'): + self.df_fmt = df_fmt + self.path = path + self.kernel_path = kernel_path + + collector_classes = { + cls.NAME: cls + for cls in get_subclasses(WACollectorBase, only_leaves=True) + } + auto_collectors = { + name: cls + for name, cls in collector_classes.items() + if not self._needs_params(cls) + } + self._auto_collectors = auto_collectors + self._available_collectors = collector_classes + class WACollectorBase(StatsProp, Loggable, abc.ABC): """ Base class for all ``Workload Automation`` dataframe collectors. @@ -383,9 +382,14 @@ class WACollectorBase(StatsProp, Loggable, abc.ABC): @memoized def df(self): """ - :class:`pandas.DataFrame` containing the data collected. + :class:`polars.DataFrame` or `pandas.Dataframe` containing the data collected. 
""" - return self._get_df() + df_fmt = self.wa_output.df_fmt + df = self._get_df() + if df_fmt == 'pandas-dataframe': + df = df.collect().to_pandas() + return df + return df def _get_df(self): self.logger.debug(f"Collecting dataframe for {self.NAME}") @@ -404,10 +408,13 @@ class WACollectorBase(StatsProp, Loggable, abc.ABC): if self._PURE_GET_JOB_DF: try: - df = pd.read_parquet(cache_path) + # Dataframe read + df = pl.scan_parquet(cache_path) + df.clear().collect() except OSError: df = get_df() - df.to_parquet(cache_path) + # Dataframe write + df.sink_parquet(cache_path) return df else: return get_df() @@ -445,21 +452,28 @@ class WACollectorBase(StatsProp, Loggable, abc.ABC): @staticmethod def _add_job_info(job, df): - df['iteration'] = job.iteration - df['workload'] = job.label - df['id'] = job.id - df = df.assign(**job.classifiers) + df = df.with_columns([ + pl.lit(job.iteration).alias('iteration'), + pl.lit(job.label).alias('workload'), + pl.lit(job.id).alias('id') + ]) + + for key, value in job.classifiers.items(): + df = df.with_columns(pl.lit(value).alias(key)) return df @staticmethod def _add_output_info(wa_output, name, df): # Kernel version kver = wa_output.target_info.kernel_version - df['kernel_name'] = kver.release - df['kernel_sha1'] = kver.sha1 + df = df.with_columns([ + pl.lit(kver.release).alias("kernel_name"), + pl.lit(kver.sha1).alias("kernel_sha1") + ]) # Folder of origin - df['wa_path'] = name + df = df.with_columns(pl.lit(name).alias("wa_path")) + return df def _add_kernel_id(self, df): @@ -479,7 +493,7 @@ class WACollectorBase(StatsProp, Loggable, abc.ABC): kernel_ids = { sha1: resolve_readable(sha1) - for sha1 in df['kernel_sha1'].unique() + for sha1 in df.select("kernel_sha1").collect().unique().to_series().to_list() if sha1 is not None } @@ -491,10 +505,13 @@ class WACollectorBase(StatsProp, Loggable, abc.ABC): for sha1, ref in kernel_ids.items() } - df['kernel'] = df['kernel_sha1'].map(kernel_ids).fillna( - df['kernel_name'] - ) - df.drop(columns=['kernel_sha1', 'kernel_name'], inplace=True) + df = df.with_columns( + kernel=pl.col("kernel_sha1") + .replace_strict(kernel_ids, default=None) + .fill_null(pl.col("kernel_name"))) + + df = df.drop(['kernel_sha1', 'kernel_name']) + return df @@ -507,7 +524,7 @@ class WAResultsCollector(WACollectorBase): @classmethod def _get_job_df(cls, job): - df = pd.DataFrame.from_records( + df = pl.LazyFrame( { **metric.classifiers, 'metric': metric.name, @@ -577,19 +594,19 @@ class WAEnergyCollector(WAArtifactCollectorBase): ) def _get_artifact_df(self, path): - df = pd.read_csv(path) + df = pd.scan_csv(path) # Record the CSV line as sample nr so each measurement is uniquely # identified by a metric and a sample number - df.index.name = 'sample' - df.reset_index(inplace=True) + df = df.with_row_index('sample') - df = df.melt(id_vars=['sample'], var_name='metric') + df = df.unpivot(index='sample', on='metric') suffix_unit = self._ARTIFACT_METRIC_SUFFIX_UNIT - df['unit'] = df['metric'].apply( - lambda x: suffix_unit[x.rsplit('_', 1)[-1]] - ) + df = df.with_columns( + unit=pl.col('metric') + .map_elements(lambda x: suffix_unit[x.rsplit('_', 1)[-1]])) + return df def get_stats(self, **kwargs): @@ -698,12 +715,18 @@ class WAJankbenchCollector(WAArtifactCollectorBase): _ARTIFACT_NAME = "jankbench-results" def _get_artifact_df(self, path): with contextlib.closing(sqlite3.connect(path)) as con: - raw_df = pd.read_sql_query("SELECT total_duration, jank_frame, name, _id as frame_id from ui_results", con) + raw_df = pd.read_database( + "SELECT 
+                "SELECT total_duration, jank_frame, name, _id as frame_id from ui_results", con).lazy()
+
+        df = raw_df.unpivot(index=['frame_id'], on=['total_duration', 'jank_frame'])
-        df = raw_df.melt(id_vars=['frame_id'], value_vars=['total_duration', 'jank_frame'])
        # supply units - everything is ms time except jank frames
-        df['unit'] = 'ms'
-        df.loc[df['variable'] == 'jank_frame', 'unit'] = ''
+        df = df.with_columns(
+            pl.when(pl.col("variable") == 'jank_frame')
+            .then(pl.lit(''))
+            .otherwise(pl.lit('ms')).alias("unit")
+        )
+
        return df
 
    def get_stats(self, **kwargs):
@@ -792,7 +815,7 @@ class WASysfsExtractorCollector(WAArtifactCollectorBase):
 
        raw_file = [line.strip() for line in raw_file]
 
-        df = pd.DataFrame(data={
+        df = pl.LazyFrame(data={
            'variable': self._filename,
            'value': raw_file,
            'unit': ''