Source code for tsunami_ip_utils.viz.scatter_plot

"""Tools for creating interactive and static scatter plots with error bars, linear regression lines, and correlation coefficient
calculations."""

from __future__ import annotations
import scipy.stats as stats
import numpy as np
import matplotlib.pyplot as plt
from ._base_plotter import _Plotter
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import plotly.graph_objects as go
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
import pandas as pd
import webbrowser
import os
import signal
import pickle
import threading
from .plot_utils import _find_free_port
import sys, os, signal
import threading
import webbrowser
import sys
from plotly.graph_objs import Figure
from pyparsing import *
import re
from uncertainties import ufloat
import tsunami_ip_utils
from typing import Tuple, Any, Union, List
import typing
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from pathlib import Path

[docs] def _replace_spearman_and_pearson(text: str, new_pearson: float, new_spearman: float) -> str: """Replaces the Spearman and Pearson values in the given text with the new values provided using regex. This is useful for updating the correlation statistics in a scatter plot annotation interactively after the plot has been created. Parameters ---------- text The text (scatter plot annotation) containing the Spearman and Pearson values to be updated. It is expected that the text contains the Spearman and Pearson values in the format ``"Spearman: <b>0.123456</b> Pearson: <b>0.123456</b>"``. new_pearson The new Pearson correlation coefficient value to replace the old value in the text. new_spearman The new Spearman rank correlation coefficient value to replace the old value in the text. Returns ------- The updated text with the new Spearman and Pearson values.""" new_pearson = f"{new_pearson:1.6f}" new_spearman = f"{new_spearman:1.6f}" # Replace Pearson value using regex pearson_pattern = r"(Pearson: <b>)[0-9.]+(</b>)" text = re.sub(pearson_pattern, r"\g<1>" + new_pearson + r"\2", text) # Replace Spearman value using regex spearman_pattern = r"(Spearman: <b>)[0-9.]+(</b>)" text = re.sub(spearman_pattern, r"\g<1>" + new_spearman + r"\2", text) return text
def _update_percent_difference(text: str, reference_value: float) -> Tuple[str, ufloat]:
    """Updates the percent difference in the given text with the new percent difference calculated using the new
    Pearson correlation coefficient value.

    This is useful for updating the percent difference in a scatter plot annotation interactively after the plot has
    been created. This is only valid for scatter plots generated by
    :func:`tsunami_ip_utils.comparisons.correlation_comparison` or other plots that have the percent difference in the
    annotation.

    Parameters
    ----------
    text
        The text (scatter plot annotation) containing the percent difference to be updated. It is expected that the
        text contains the percent difference in the format ``"Percent Difference: <b>0.123456</b>%"`` and the
        TSUNAMI-IP c_k value in the format ``"TSUNAMI-IP c_k: <b>0.123456+/-0.123456</b>"``. The c_k nominal value may
        be signed and may use scientific notation.
    reference_value
        The reference value to calculate the percent difference from. This is typically the (updated) Pearson
        correlation coefficient.

    Returns
    -------
        - updated_text
            The updated text with the new percent difference value.
        - percent_difference
            The new percent difference value with uncertainty. It is ``0+/-0`` when ``reference_value`` is zero or NaN.

    Raises
    ------
    ValueError
        If the TSUNAMI-IP c_k value cannot be found in ``text``."""
    # Signed decimal number, optionally in scientific notation
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    match = re.search(rf"TSUNAMI-IP c_k: <b>({number})\+/-({number})</b>", text)
    if match is None:
        raise ValueError("TSUNAMI-IP c_k value not found in the text")

    nominal_value, uncertainty = match.groups()
    tsunami_ck = ufloat(float(nominal_value), float(uncertainty))

    # The relative difference is undefined for a zero/NaN reference, so report zero instead of dividing by it
    if reference_value == 0 or np.isnan(reference_value):
        percent_difference = ufloat(0, 0)
    else:
        percent_difference = (tsunami_ck - reference_value) / reference_value * 100

    new_percent_diff_text = f"<b>{percent_difference.nominal_value:.2f}+/-{percent_difference.std_dev:.2f}</b>%"

    # Replace whatever currently sits between the bold tags (signs, '+/-', or 'nan'); the callable replacement keeps
    # the new text literal
    updated_text = re.sub(r"Percent Difference: <b>[^<]*</b>%",
                          lambda _match: f"Percent Difference: {new_percent_diff_text}", text)
    return updated_text, percent_difference
class EnhancedPlotlyFigure(go.Figure):
    """This class wraps a plotly express figure object (intended for a scatter plot) and adds additional attributes for
    the summary statistics and linear regression data. This class is intended to be used with the
    :class:`.InteractiveScatterPlotter` class.

    Notes
    -----
    The base class is referenced explicitly as ``go.Figure`` because this module imports both
    ``plotly.graph_objs.Figure`` and ``matplotlib.figure.Figure`` under the bare name ``Figure``. The later
    (matplotlib) import shadows the plotly one."""
    statistics: typing.Optional[dict]
    """A dictionary containing the Pearson and Spearman correlation coefficients (``None`` until set)."""
    regression: typing.Optional[dict]
    """A dictionary containing the slope and intercept of the linear regression line (``None`` until set)."""

    def __init__(self, *args, **kwargs):
        """Initializes an EnhancedPlotlyFigure object from a Plotly Express figure object.

        Intended Use
        ============
        This class should be initialized from a plotly express Figure object via:

        .. code-block:: python

            fig = px.scatter(...)
            fig = EnhancedPlotlyFigure(fig.to_dict())

        Additional ``statistics`` and ``regression`` attributes can then be assigned on this "enhanced" figure via

        .. code-block:: python

            fig.statistics = {
                'pearson': 0.123456,
                'spearman': 0.123456
            }
            fig.regression = {
                'slope': 0.123456,
                'intercept': 0.123456
            }

        unlike a regular Plotly figure object, which raises an error when unknown attributes are assigned directly.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to :class:`plotly.graph_objects.Figure`."""
        super().__init__(*args, **kwargs)
        # Plotly's __setattr__ rejects unknown public attribute names, so seed these via object.__setattr__
        object.__setattr__(self, 'statistics', None)
        object.__setattr__(self, 'regression', None)

    def __setattr__(self, name, value):
        """Sets an attribute on the figure.

        ``'statistics'`` and ``'regression'`` are stored directly on the instance (bypassing Plotly's attribute
        validation); every other attribute is delegated to the Plotly base class."""
        if name in ('statistics', 'regression'):
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)
class _ScatterPlot(_Plotter):
    """Shared base for every scatter plot flavor (interactive or static).

    It provides the computation of the linear regression and the Pearson / Spearman correlation statistics, plus the
    annotation text that summarizes them."""
    _regression: Any
    """A named tuple (``scipy.stats`` ``Linregress`` object) containing the slope and intercept of the linear regression
    line as attributes ``slope`` and ``intercept`` respectively."""
    _pearson: float
    """The Pearson correlation coefficient."""
    _spearman: float
    """The Spearman rank correlation coefficient."""
    _summary_stats_text: str
    """A string containing the summary statistics for the scatter plot."""

    def _get_summary_statistics(self, x: Union[List, np.ndarray], y: Union[List, np.ndarray]) -> None:
        """Compute the linear regression and correlation statistics of ``y`` against ``x`` and store them.

        After the call, the instance holds:

        - ``_regression``: the ``scipy.stats.linregress`` result
        - ``_pearson`` / ``_spearman``: the Pearson and Spearman coefficients
        - ``_slope`` / ``_intercept``: the regression line parameters
        - ``_summary_stats_text``: an HTML-bold annotation string ``"Pearson: <b>…</b> Spearman: <b>…</b>"``

        If ``self.fig`` exists and is an :class:`EnhancedPlotlyFigure`, its ``statistics`` and ``regression``
        attributes are updated with the same values.

        Parameters
        ----------
        x
            The x values of the scatter plot.
        y
            The y values of the scatter plot."""
        self._regression = stats.linregress(x, y)
        self._pearson = stats.pearsonr(x, y).statistic
        self._spearman = stats.spearmanr(x, y).statistic
        self._slope = self._regression.slope
        self._intercept = self._regression.intercept

        # Only enhanced (plotly) figures can carry this metadata
        figure = getattr(self, 'fig', None)
        if isinstance(figure, EnhancedPlotlyFigure):
            figure.statistics = {'pearson': self._pearson, 'spearman': self._spearman}
            figure.regression = {'slope': self._slope, 'intercept': self._intercept}

        self._summary_stats_text = " ".join((
            f"Pearson: <b>{self._pearson:1.6f}</b>",
            f"Spearman: <b>{self._spearman:1.6f}</b>",
        ))
class _ScatterPlotter(_ScatterPlot):
    """A class for creating static (matplotlib) scatter plots with error bars, linear regression lines, and correlation
    coefficient calculations."""
    _nested: bool
    """Whether the contributions are nested or not. This is unused in this class."""
    _index_name: str
    """The name of the integral index (whose contributions) being plotted. This is used for the plot title and axis
    labels."""
    _plot_redundant: bool
    """Whether to plot redundant reactions or not. This is unused in this class."""

    def __init__(self, integral_index_name: str, nested: bool, plot_redundant: bool=False, **kwargs: dict) -> None:
        """Initializes a ScatterPlotter object with the given integral index name, nested status, and plot redundant
        reactions status. The nested and plot_redundant arguments are stored but unused in this class.

        Parameters
        ----------
        integral_index_name
            The name of the integral index (whose contributions) being plotted.
        nested
            Whether the contributions are nested or not.
        plot_redundant
            Whether to plot redundant reactions or not.
        kwargs
            Additional keyword arguments. These are unused in this class."""
        self._nested = nested
        self._index_name = integral_index_name
        self._plot_redundant = plot_redundant

    def _create_plot(self, contribution_pairs: List[ufloat], isotopes: List[str], reactions: List[str]) -> None:
        """Creates a static scatter plot with error bars, linear regression line, and correlation coefficient
        calculations. The result is stored in ``self.fig`` (Figure) and ``self.axs`` (Axes).

        Parameters
        ----------
        contribution_pairs
            A list of pairs of contributions to the integral index from the application and experiment.
        isotopes
            The list of isotopes represented by each contribution pair. Unused here.
        reactions
            The list of reactions represented by each contribution pair. Unused here.

        Notes
        -----
        The set of isotopes and reactions is only necessary for creating labels in interactive plots (e.g. those made
        by :class:`.InteractiveScatterPlotter`), but is included here for a consistent interface."""
        self.fig, self.axs = plt.subplots()

        # Split each (application, experiment) pair into nominal values and standard deviations
        application_points = [contribution[0].n for contribution in contribution_pairs]
        application_uncertainties = [contribution[0].s for contribution in contribution_pairs]
        experiment_points = [contribution[1].n for contribution in contribution_pairs]
        experiment_uncertainties = [contribution[1].s for contribution in contribution_pairs]

        # Draw on the explicit Axes and do NOT rebind ``self.fig``: it must stay the Figure from ``plt.subplots`` so
        # that ``_get_plot`` returns ``(Figure, Axes)`` rather than the container returned by ``errorbar``.
        self.axs.errorbar(application_points, experiment_points, xerr=application_uncertainties,
                          yerr=experiment_uncertainties, fmt='.', capsize=5)

        # Linear regression and correlation statistics
        self._get_summary_statistics(application_points, experiment_points)

        x = np.linspace(min(application_points), max(application_points), 100)
        y = self._slope * x + self._intercept
        self.axs.plot(x, y, 'r', label='Linear fit')

        # The shared summary text uses HTML bold tags (for Plotly); matplotlib would render them literally
        stats_text = re.sub(r"</?b>", "", self._summary_stats_text)
        self.axs.text(0.05, 0.95, stats_text, transform=self.axs.transAxes, fontsize=12,
                      verticalalignment='top', bbox=dict(facecolor='white', alpha=0.5))

        self._style()

    def _get_plot(self) -> Tuple[Figure, Axes]:
        """Return the matplotlib ``(Figure, Axes)`` pair created by :meth:`_create_plot`."""
        return self.fig, self.axs

    def _add_to_subplot(self, fig, position) -> Axes:
        """Add a new subplot at ``position`` to ``fig`` whose x and y axes are shared with this plot's Axes, and return
        the new Axes.

        NOTE(review): this does not copy this plot's artists (points, fit line, text) onto the new subplot. Confirm
        that this is intended."""
        return fig.add_subplot(position, sharex=self.axs, sharey=self.axs)

    def _style(self):
        """Set the title, axis labels and grid of the plot based on the integral index name."""
        self.axs.set_title(f'Contributions to {self._index_name}')
        self.axs.set_ylabel(f"Experiment {self._index_name} Contribution")
        self.axs.set_xlabel(f"Application {self._index_name} Contribution")
        self.axs.grid()
def load_interactive_scatter_plot(filename: Union[str, Path]) -> InteractiveScatterLegend:
    """Load an interactive scatter plot (with interactive legend) from a saved state file.

    This is a convenience wrapper around :meth:`.InteractiveScatterLegend.load_state`.

    Parameters
    ----------
    filename
        Path to a state file previously written by :meth:`.InteractiveScatterLegend.save_state`.

    Returns
    -------
        A deserialized :class:`.InteractiveScatterLegend` that can be displayed with
        :meth:`InteractiveScatterLegend.show()`.

    Notes
    -----
    The state file is read with :mod:`pickle`; only load files from trusted sources."""
    return InteractiveScatterLegend.load_state(filename=filename)
class _InteractiveScatterPlotter(_ScatterPlot):
    """A class for creating `interactive` scatter plots with error bars, linear regression lines, and correlation
    coefficient calculations."""
    _interactive_legend: bool
    """Whether to include an interactive legend in the plot. This is used to toggle the visibility of the traces in the
    plot and interactively recalculate the regression and summary statistics."""
    _nested: bool
    """Whether the contributions are nested or not. This is unused in this class."""
    _index_name: str
    """The name of the integral index (whose contributions) being plotted. This is used for the plot title and the data
    column names."""
    _plot_redundant: bool
    """Whether to plot redundant reactions or not. This is unused in this class."""

    def __init__(self, integral_index_name: str, nested: bool, plot_redundant: bool=False, **kwargs: dict):
        """Initializes an ``InteractiveScatterPlotter`` object with the given options.

        Parameters
        ----------
        integral_index_name
            The name of the integral index (whose contributions) being plotted.
        nested
            Whether the contributions are nested or not.
        plot_redundant
            Whether to plot redundant reactions or not.
        kwargs
            Additional keyword arguments.

            - interactive_legend (bool)
                Whether to include an interactive legend in the plot. Default is ``False``."""
        self._interactive_legend = kwargs.get('interactive_legend', False)
        self._nested = nested
        self._index_name = integral_index_name
        self._plot_redundant = plot_redundant

    def _create_plot(self, contribution_pairs: List[ufloat], isotopes: List[str], reactions: List[str]) -> None:
        """Create an interactive scatter plot with error bars, linear regression line, and correlation coefficient
        calculations. The result is stored in ``self.fig``. It is an :class:`InteractiveScatterLegend` when the
        interactive legend is enabled, otherwise an :class:`EnhancedPlotlyFigure`.

        Parameters
        ----------
        contribution_pairs
            A list of pairs of contributions to the integral index from the application and experiment.
        isotopes
            The list of isotopes represented by the contribution pairs.
        reactions
            The list of reactions represented by the contribution pairs (may be empty). See
            :meth:`_create_scatter_data` for how isotopes and reactions are paired with ``contribution_pairs``."""
        df = self._create_scatter_data(contribution_pairs, isotopes, reactions)

        hover_data_dict = {'Isotope': True}  # Always include Isotope
        if 'Reaction' in df.columns:
            hover_data_dict['Reaction'] = True  # Include Reaction only if it exists

        figure = px.scatter(
            df,
            x=f'Application {self._index_name} Contribution',
            y=f'Experiment {self._index_name} Contribution',
            error_x='Application Uncertainty',
            error_y='Experiment Uncertainty',
            color='Isotope',
            labels={"color": "Isotope"},
            title=f'Contributions to {self._index_name}',
            hover_data=hover_data_dict,
        )
        # Wrap the plotly express figure so regression / correlation metadata can be attached to it
        self.fig = EnhancedPlotlyFigure(figure.to_dict())
        self._add_regression_and_stats(df)

        self._style()

        if self._interactive_legend:
            self.fig = InteractiveScatterLegend(self, df)

    def _add_regression_and_stats(self, df: pd.DataFrame) -> None:
        """Add (or replace) the linear regression line and correlation statistics annotation on the plot.

        This method is called after the plot has been created and may be called again with a filtered DataFrame. If the
        existing statistics annotation contains a TSUNAMI-IP :math:`c_k` comparison, its percent difference is
        updated as well.

        Parameters
        ----------
        df
            The DataFrame containing the data for the scatter plot. This DataFrame should contain the columns
            ``'Application <integral_index_name> Contribution'`` and ``'Experiment <integral_index_name> Contribution'``."""
        application_column = f'Application {self._index_name} Contribution'
        experiment_column = f'Experiment {self._index_name} Contribution'

        self._get_summary_statistics(df[application_column], df[experiment_column])

        x_reg = np.linspace(df[application_column].min(), df[application_column].max(), 100)
        y_reg = self._slope * x_reg + self._intercept

        # Drop any previous regression line; trace names may be None for unnamed traces
        self.fig.data = tuple(trace for trace in self.fig.data
                              if not (trace.name or '').startswith('Regression Line'))
        self.fig.add_trace(go.Scatter(x=x_reg, y=y_reg, mode='lines',
                                      name=f'Regression Line y={self._slope:1.4E}x + {self._intercept:1.4E}'))

        annotation_text, bordercolor = self._refresh_stats_annotation()

        self.fig.add_annotation(
            x=0.05,
            xref="paper",
            y=0.95,
            yref="paper",
            text=annotation_text,
            bordercolor=bordercolor,
            showarrow=False,
            font=dict(size=12),
            align='left',
            bgcolor="white",
            opacity=0.8
        )

    def _refresh_stats_annotation(self) -> Tuple[str, str]:
        """Remove the current statistics annotation (if any) and return ``(text, border_color)`` for its replacement.

        If an annotation whose text starts with ``'Pearson'`` exists, its Pearson/Spearman values are replaced, so any
        extra content it carries is preserved. When that content includes a TSUNAMI-IP comparison, the percent
        difference is recomputed, and the border turns red if it exceeds 5%. Otherwise the freshly computed
        ``_summary_stats_text`` is returned with the default border color."""
        default_border = '#444'
        layout = getattr(self.fig, 'layout', None)
        annotations = tuple(getattr(layout, 'annotations', None) or ())

        stats_annotations = [ann for ann in annotations if (ann.text or '').startswith('Pearson')]
        if not stats_annotations:
            return self._summary_stats_text, default_border

        annotation_text = _replace_spearman_and_pearson(stats_annotations[0].text, self._pearson, self._spearman)
        # Remove the stale annotation; the caller adds the refreshed one
        self.fig.layout.annotations = [ann for ann in annotations if not (ann.text or '').startswith('Pearson')]

        bordercolor = default_border
        if "TSUNAMI-IP" in annotation_text:
            try:
                annotation_text, percent_difference = _update_percent_difference(annotation_text, self._pearson)
            except Exception as e:
                # Best effort: keep the updated correlation values even if the c_k comparison cannot be refreshed
                print(f"Error updating percent difference: {e}")
            else:
                if abs(percent_difference.nominal_value) > 5:
                    bordercolor = 'red'
        return annotation_text, bordercolor

    def _create_scatter_data(self, contribution_pairs: List[ufloat], isotopes: List[str],
                             reactions: List[str]) -> pd.DataFrame:
        """Create the DataFrame used to draw the interactive scatter plot.

        Rows where both the application and the experiment contribution are exactly zero are dropped.

        Parameters
        ----------
        contribution_pairs
            A list of pairs of contributions to the integral index from the application and experiment.
        isotopes
            The isotopes labelling the contribution pairs.
        reactions
            The reactions labelling the contribution pairs, or an empty list / ``None`` when there are none.

        Notes
        -----
        Without reactions, the i-th isotope labels the i-th pair. With reactions, one label row is generated per
        (isotope, reaction) combination in isotope-major order. NOTE(review): this presumably relies on
        ``contribution_pairs`` being ordered the same way, with ``len(isotopes) * len(reactions)`` entries. Confirm
        with callers.

        Returns
        -------
            A DataFrame with the application/experiment contribution and uncertainty columns, ``'Isotope'``, and (only
            when reactions are given) ``'Reaction'``."""
        application_column = f'Application {self._index_name} Contribution'
        experiment_column = f'Experiment {self._index_name} Contribution'
        data = {
            application_column: [cp[0].n for cp in contribution_pairs],
            experiment_column: [cp[1].n for cp in contribution_pairs],
            'Application Uncertainty': [cp[0].s for cp in contribution_pairs],
            'Experiment Uncertainty': [cp[1].s for cp in contribution_pairs],
        }

        if not reactions:
            data['Isotope'] = list(isotopes)
        else:
            data['Isotope'] = [isotope for isotope in isotopes for _ in reactions]
            data['Reaction'] = [reaction for _ in isotopes for reaction in reactions]

        # Filter out (0, 0) points, which contribute to neither the application nor the experiment. These are usually
        # chi, nubar, or fission reactions for nonfissile isotopes that are only present for consistency with the set
        # of reactions.
        keep = [app != 0 or exp != 0 for app, exp in zip(data[application_column], data[experiment_column])]
        data = {key: [value for value, kept in zip(values, keep) if kept] for key, values in data.items()}
        return pd.DataFrame(data)

    def _add_to_subplot(self, fig, position):
        """Copy this plot's traces and annotations into subplot ``(row, col) = position`` of ``fig`` and return ``fig``.

        NOTE(review): ``ann.update`` mutates this plot's own annotations in place, and ``xref``/``yref`` are derived
        from the column only. Confirm this is correct for multi-row layouts."""
        for trace in self.fig.data:
            fig.add_trace(trace, row=position[0], col=position[1])

        # Transfer annotations
        if hasattr(self.fig, 'layout') and hasattr(self.fig.layout, 'annotations'):
            for ann in self.fig.layout.annotations:
                # Adjust annotation references to new subplot
                new_ann = ann.update(xref=f'x{position[1]}', yref=f'y{position[1]}')
                fig.add_annotation(new_ann, row=position[0], col=position[1])

        return fig

    def _get_plot(self):
        """Return the created figure (``self.fig``)."""
        return self.fig

    def _style(self):
        """Set a centered title naming the integral index."""
        title_text = f'Contributions to {self._index_name}'
        self.fig.update_layout(title_text=title_text, title_x=0.5)  # 'title_x=0.5' centers the title
class _InteractivePerturbationScatterPlotter(_ScatterPlot):
    """Class for creating an interactive scatter plot using the nuclear data sampling method (where perturbed cross
    section libraries are used to calculate sample points on the scatter plot)."""

    def __init__(self, **kwargs: dict):
        """Accept and ignore arbitrary keyword arguments; no state is initialized here."""
        pass

    def _create_plot(self, points: List[Tuple[ufloat, ufloat]]) -> None:
        """Create an interactive perturbation scatter plot with error bars, linear regression line, and correlation
        coefficient. The result is stored in ``self.fig``.

        Parameters
        ----------
        points
            A list of perturbation points, each computed from sampled perturbed cross section libraries. These points
            are generated using the :func:`tsunami_ip_utils.perturbations.generate_points` function."""
        df = pd.DataFrame({
            'Application': [point[0].n for point in points],
            'Experiment': [point[1].n for point in points],
            'Application Uncertainty': [point[0].s for point in points],
            'Experiment Uncertainty': [point[1].s for point in points]
        })

        figure = px.scatter(
            df,
            x='Application',
            y='Experiment',
            error_x='Application Uncertainty',
            error_y='Experiment Uncertainty',
            title='Correlation Plot',
        )
        # Wrap the plotly express figure so regression / correlation metadata can be attached to it
        self.fig = EnhancedPlotlyFigure(figure.to_dict())
        self._add_regression_and_stats(df)

        self._style()

    def _add_regression_and_stats(self, df: pd.DataFrame) -> None:
        """Add (or replace) the linear regression line and correlation statistics annotation on the plot.

        Parameters
        ----------
        df
            The DataFrame containing the data for the scatter plot. This DataFrame should contain the columns
            ``'Application'`` and ``'Experiment'``."""
        self._get_summary_statistics(df['Application'], df['Experiment'])

        x_reg = np.linspace(df['Application'].min(), df['Application'].max(), 100)
        y_reg = self._slope * x_reg + self._intercept

        # Drop any previous regression line; trace names may be None for unnamed traces
        self.fig.data = tuple(trace for trace in self.fig.data
                              if not (trace.name or '').startswith('Regression Line'))
        self.fig.add_trace(go.Scatter(x=x_reg, y=y_reg, mode='lines',
                                      name=f'Regression Line y={self._slope:1.4E}x + {self._intercept:1.4E}'))

        # Drop a stale statistics annotation so repeated calls do not stack annotations
        annotations = self.fig.layout.annotations
        kept_annotations = [ann for ann in annotations if not (ann.text or '').startswith('Pearson')]
        if len(kept_annotations) != len(annotations):
            self.fig.layout.annotations = kept_annotations

        self.fig.add_annotation(
            x=0.05,
            xref="paper",
            y=0.95,
            yref="paper",
            text=self._summary_stats_text,
            showarrow=False,
            align='left',
            font=dict(size=12),
            bgcolor="white",
            opacity=0.8
        )

    def _add_to_subplot(self, fig, position):
        """Copy this plot's traces into subplot ``(row, col) = position`` of ``fig`` and return ``fig``."""
        for trace in self.fig.data:
            fig.add_trace(trace, row=position[0], col=position[1])
        return fig

    def _get_plot(self):
        """Return the created figure (``self.fig``)."""
        return self.fig

    def _style(self):
        """No additional styling is applied to perturbation plots."""
        pass
class InteractiveScatterLegend(_InteractiveScatterPlotter):
    """An implementation of an interactive legend for an interactive (Plotly) scatter plot. It automatically updates the
    regression and summary statistics after certain data are excluded from the scatter plot."""
    fig: EnhancedPlotlyFigure
    """The Plotly figure object containing the interactive scatter plot. It is a plain ``go.Figure`` when restored
    via :meth:`load_state`."""
    _interactive_scatter_plot: _InteractiveScatterPlotter
    """The interactive scatter plot object that the interactive legend is associated with."""
    df: pd.DataFrame
    """The DataFrame containing the data for the scatter plot."""
    _excluded_isotopes: List[str]
    """A list of isotopes that have been excluded from the scatter plot."""
    _app: dash.Dash
    """The Dash application object for the interactive legend."""

    def __init__(self, interactive_scatter_plot: _InteractiveScatterPlotter, df: pd.DataFrame):
        """Build the Dash application that displays ``interactive_scatter_plot``'s figure.

        Parameters
        ----------
        interactive_scatter_plot
            The scatter plotter whose ``fig`` and ``_index_name`` are displayed and updated.
        df
            The data behind the plot. Rows are filtered by their ``'Isotope'`` column when traces are hidden."""
        self._interactive_scatter_plot = interactive_scatter_plot
        self.fig = interactive_scatter_plot.fig
        self._index_name = interactive_scatter_plot._index_name
        self.df = df
        self._excluded_isotopes = []  # Keep track of excluded isotopes
        self._app = dash.Dash(__name__)
        self._app.layout = html.Div([
            dcc.Graph(id='interactive-scatter', figure=self.fig, style={'height': '100vh'})
        ], style={'margin': 0})
        self._setup_callbacks()

    def _setup_callbacks(self):
        """Set up the Dash callbacks for the interactive legend and the ``/shutdown`` server route."""
        @self._app.callback(
            Output('interactive-scatter', 'figure'),
            Input('interactive-scatter', 'restyleData'),
            State('interactive-scatter', 'figure')
        )
        def update_figure_on_legend_click(restyleData, current_figure_state):
            """Recompute the regression and statistics when legend clicks hide or show traces."""
            # restyleData is [{'visible': [...]}, [trace indices]]; other restyle events are ignored
            if not restyleData or 'visible' not in restyleData[0]:
                return dash.no_update

            visibilities = restyleData[0]['visible']
            if not isinstance(visibilities, list):
                visibilities = [visibilities]
            traces = current_figure_state['data']

            # A single click changes one trace, but a double click (isolate) changes many at once, so process every
            # index. Shorter value arrays are cycled over the indices, hence the modulo.
            for position, trace_index in enumerate(restyleData[1]):
                visibility = visibilities[position % len(visibilities)]
                trace_name = traces[trace_index].get('name')
                if visibility == 'legendonly' and trace_name not in self._excluded_isotopes:
                    self._excluded_isotopes.append(trace_name)
                elif visibility is True and trace_name in self._excluded_isotopes:
                    self._excluded_isotopes.remove(trace_name)

            updated_df = self.df[~self.df['Isotope'].isin(self._excluded_isotopes)]

            try:
                self._add_regression_and_stats(updated_df)
            except ValueError as e:
                # e.g. too few or identical remaining points for a regression; keep the previous statistics
                print(f"Could not update regression and statistics: {e}")

            # Re-sync trace visibility with the excluded isotopes
            for trace in self.fig.data:
                trace.visible = 'legendonly' if trace.name in self._excluded_isotopes else True

            return self.fig

        @self._app.server.route('/shutdown', methods=['POST'])
        def shutdown():
            """Stop the server by sending SIGINT to the current process."""
            os.kill(os.getpid(), signal.SIGINT)
            return 'Server shutting down...'

    def show(self) -> None:
        """Display the interactive scatter plot with the interactive legend in a web browser.

        This blocks while the Dash server runs. The page posts to ``/shutdown`` when the tab or window is closed, which
        stops the server. ``sys.stderr`` is silenced only while the server runs and is restored afterwards."""
        port = _find_free_port()

        def open_browser():
            """Open the app in a browser tab unless ``WERKZEUG_RUN_MAIN`` is set in the environment."""
            if not os.environ.get("WERKZEUG_RUN_MAIN"):
                print(f"Now running at http://localhost:{port}/")
                webbrowser.open(f"http://localhost:{port}/")

        # Disable Flask development server warning
        os.environ['FLASK_ENV'] = 'development'

        # JavaScript code to detect when the tab or window is closed
        self._app.index_string = '''
<!DOCTYPE html>
<html>
    <head>
        {%metas%}
        <title>{%title%}</title>
        {%favicon%}
        {%css%}
    </head>
    <body style="margin: 0;">
        {%app_entry%}
        <footer>
            {%config%}
            {%scripts%}
            <script type="text/javascript">
                window.addEventListener("beforeunload", function (e) {
                    var xhr = new XMLHttpRequest();
                    xhr.open("POST", "/shutdown", false);
                    xhr.send();
                });
            </script>
            {%renderer%}
        </footer>
    </body>
</html>
'''

        # Timer to open the browser shortly after the server starts
        threading.Timer(1, open_browser).start()

        # Dash >= 3 removed ``run_server`` in favour of ``run``; fall back for older Dash versions
        run = getattr(self._app, 'run', None) or self._app.run_server

        # Silence the Flask development server logging, but only while the server runs: the devnull handle is closed
        # and stderr restored even when the server is stopped via SIGINT from the /shutdown route.
        original_stderr = sys.stderr
        with open(os.devnull, 'w') as devnull:
            sys.stderr = devnull
            try:
                run(debug=False, host='localhost', port=port)
            finally:
                sys.stderr = original_stderr

    def save_state(self, filename: typing.Optional[Union[str, Path]]=None) -> typing.Optional[dict]:
        """Save the current state of the interactive scatter plot.

        The state is a dictionary containing the figure, DataFrame, excluded isotopes, index name, and nested status.
        It can later be restored with :meth:`InteractiveScatterLegend.load_state` or the
        :func:`load_interactive_scatter_plot` function.

        Parameters
        ----------
        filename
            The filename to pickle the state to. If not provided, the state dictionary is returned instead of being
            saved to a file.

        Returns
        -------
            * If ``filename`` is not provided, the state dictionary is returned.
            * Otherwise, ``None`` is returned."""
        state = {
            'fig': self.fig.to_dict(),
            'df': self.df.to_dict(),
            'excluded_isotopes': self._excluded_isotopes,
            'index_name': self._index_name,
            'nested': self._interactive_scatter_plot._nested
        }
        if filename is None:
            return state
        with open(filename, 'wb') as f:
            pickle.dump(state, f)
        return None

    @classmethod
    def load_state(cls, filename: typing.Optional[Union[str, Path]]=None,
                   data_dict: typing.Optional[dict]=None) -> InteractiveScatterLegend:
        """Load a saved state of an interactive scatter plot.

        The state (figure, DataFrame, excluded isotopes, index name, and nested status) is the one produced by
        :meth:`InteractiveScatterLegend.save_state`.

        Parameters
        ----------
        filename
            The filename of the saved (pickled) state file to load the interactive scatter plot from. Only load files
            from trusted sources: unpickling can execute arbitrary code.
        data_dict
            A dictionary containing the saved state of the interactive scatter plot. This is an alternative to
            providing the filename; ``filename`` takes precedence when both are given.

        Returns
        -------
            An instance of the interactive scatter plot that can be plotted with :meth:`InteractiveScatterLegend.show()`.

        Raises
        ------
        ValueError
            If neither ``filename`` nor ``data_dict`` is provided."""
        if filename is None and data_dict is None:
            raise ValueError("Either a filename or a data dictionary must be provided")

        if filename is not None:
            # SECURITY: pickle.load executes arbitrary code from malicious files; only use trusted state files
            with open(filename, 'rb') as f:
                state = pickle.load(f)
        else:
            state = data_dict

        # Recreate the _InteractiveScatterPlotter instance from the saved state
        fig = go.Figure(state['fig'])
        interactive_scatter_plot = _InteractiveScatterPlotter(state['index_name'], state['nested'])
        interactive_scatter_plot.fig = fig

        # Recreate the InteractiveScatterLegend instance from the saved state
        instance = cls(interactive_scatter_plot, pd.DataFrame.from_dict(state['df']))
        instance._excluded_isotopes = state['excluded_isotopes']

        # Update trace visibility based on excluded isotopes
        for trace in instance.fig.data:
            trace.visible = 'legendonly' if trace.name in instance._excluded_isotopes else True

        return instance

    def write_html(self, filename: Union[str, Path]) -> None:
        """Save the current state of the interactive scatter plot to an HTML file that can be viewed in a web browser.

        Parameters
        ----------
        filename
            The filename to save the interactive scatter plot to.

        Notes
        -----
        Since the legend interactivity is implemented in Python, the HTML file only captures the current state of the
        plot; the statistics will not update when traces are toggled in the saved file."""
        self.fig.write_html(filename)