# Source code for trtoolbox.plothelper

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from matplotlib.widgets import CheckButtons
import matplotlib.colors as colors
from scipy.interpolate import interp2d
from mpl_toolkits import mplot3d


class PlotHelper:
    """
    Object for interactive plotting.

    This class keeps references to all matplotlib objects created by the
    interactive plot methods (figures, axes, lines and widgets), which
    ensures proper function of the sliders and check boxes.

    Attributes
    ----------
    fig_traces : matplotlib.figure.Figure
    ax_traces : matplotlib.axes.Axes
    l1_traces : matplotlib.lines.Line2D
    l2_traces : matplotlib.lines.Line2D
    axfreq : matplotlib.axes.Axes
    sfreq : matplotlib.widgets.Slider
    chxbox_traces : matplotlib.widgets.CheckButtons
    fig_spectra : matplotlib.figure.Figure
    ax_spectra : matplotlib.axes.Axes
    l1_spectra : matplotlib.lines.Line2D
    l2_spectra : matplotlib.lines.Line2D
    axtime : matplotlib.axes.Axes
    stime : matplotlib.widgets.Slider
    chxbox_spectra : matplotlib.widgets.CheckButtons
    fig_lda : matplotlib.figure.Figure
    axs_lda : np.array
    map_lda : matplotlib.contour.QuadContourSet
    ldadata : matplotlib.collections.QuadMesh
    l1_lda : matplotlib.lines.Line2D
    axalpha : matplotlib.axes.Axes
    salpha : matplotlib.widgets.Slider
    fig_svec : matplotlib.figure.Figure
    ax_svec : matplotlib.axes.Axes
    axsvec : matplotlib.axes.Axes
    ssvec : matplotlib.widgets.Slider
    chxbox_svec : matplotlib.widgets.CheckButtons
    axcolor : str
        Face color of the slider axes.
    """

    def __init__(self):
        # Placeholder objects so that every attribute exists (with the
        # documented type) before the corresponding plot method is called.
        self.fig_traces = matplotlib.figure.Figure()
        self.ax_traces = matplotlib.axes.Axes(self.fig_traces, [0, 0, 0, 0])
        self.l1_traces = matplotlib.lines.Line2D([], [])
        self.l2_traces = matplotlib.lines.Line2D([], [])
        self.axfreq = matplotlib.axes.Axes(self.fig_traces, [0, 0, 0, 0])
        self.sfreq = matplotlib.widgets.Slider(self.axfreq, '', 0, 1)
        self.chxbox_traces = None
        self.ymin_traces = None
        self.ymax_traces = None
        self.yadjust_traces = False

        self.fig_spectra = matplotlib.figure.Figure()
        self.ax_spectra = matplotlib.axes.Axes(self.fig_spectra, [0, 0, 0, 0])
        self.l1_spectra = matplotlib.lines.Line2D([], [])
        self.l2_spectra = matplotlib.lines.Line2D([], [])
        self.axtime = matplotlib.axes.Axes(self.fig_spectra, [0, 0, 0, 0])
        self.stime = matplotlib.widgets.Slider(self.axtime, '', 0, 1)
        self.chxbox_spectra = None
        self.ymin_spectra = None
        self.ymax_spectra = None
        self.yadjust_spectra = False

        self.fig_lda = matplotlib.figure.Figure()
        self.axs_lda = np.empty((2, 2), dtype=object)
        self.map_lda = None
        self.ldadata = None
        self.l1_lda = matplotlib.lines.Line2D([], [])
        self.axalpha = matplotlib.axes.Axes(self.fig_lda, [0, 0, 0, 0])
        self.salpha = matplotlib.widgets.Slider(self.axalpha, '', 0, 1)

        self.fig_svec = None
        self.ax_svec = None
        self.axsvec = None
        self.ssvec = None
        self.chxbox_svec = None
        self.ymin_svec = None
        self.ymax_svec = None

        self.axcolor = 'lightgoldenrodyellow'

    @staticmethod
    def _padded_limits(values, scale=0.05):
        """
        Return ``(lower, upper)`` axis limits enclosing *values* with a margin.

        Each limit is moved outward by ``scale`` times its own magnitude, so
        the margin is correct for positive and negative extrema alike (a
        plain ``min*1.05`` would clip a positive minimum). NaNs are ignored.
        All-zero data, which would otherwise give identical limits, gets the
        span ``(-1, 1)``.
        """
        vmin = float(np.nanmin(values))
        vmax = float(np.nanmax(values))
        lower = vmin - scale * abs(vmin)
        upper = vmax + scale * abs(vmax)
        if lower == upper:
            lower, upper = lower - 1.0, upper + 1.0
        return lower, upper

    @staticmethod
    def _select_procdata(res, alpha, index_alpha, prefix):
        """
        Pick the processed data matrix and the figure title for *res*.

        Parameters
        ----------
        res : results object
            Result with a ``type`` of 'raw', 'svd', 'gf' or 'lda'.
        alpha : float
            Alpha value, forwarded to ``res.get_alpha`` for Tikhonov LDA.
        index_alpha : int
            Index of alpha value, forwarded to ``res.get_alpha``.
        prefix : str
            First title line, e.g. 'Time traces' or 'Spectra'.

        Returns
        -------
        title : str
        procdata : np.array or None
            None for raw data (nothing to compare against).

        Raises
        ------
        ValueError
            For an unknown result type or LDA method.
        """
        if res.type == 'raw':
            return 'Raw data', None
        if res.type == 'svd':
            label, procdata = 'SVD', res.svddata
        elif res.type == 'gf':
            label, procdata = 'Global Fit', res.fitdata
        elif res.type == 'lda' and res.method == 'tik':
            alpha, index_alpha = res.get_alpha(alpha, index_alpha)
            label = 'LDA (alpha=%.2f)' % alpha
            procdata = res.fitdata[:, :, index_alpha]
        elif res.type == 'lda' and res.method == 'tsvd':
            label = 'LDA (TSVD truncated at %i)' % res.k
            procdata = res.fitdata
        else:
            raise ValueError(
                'Unsupported result type/method: %r/%r'
                % (res.type, getattr(res, 'method', None))
            )
        return '%s\nblue: Raw, red: %s' % (prefix, label), procdata

    @staticmethod
    def do_interpolate(data, time, wn, step=0.5):
        """
        Interpolates data along the frequency axis with a bicubic spline.

        Parameters
        ----------
        data : np.array
            Data matrix with rows matching ``wn`` and columns matching
            ``time``.
        time : np.array
            Time points (any shape with ``time.size`` elements, e.g. 1 x n).
        wn : np.array
            Frequency points (any shape with ``wn.size`` elements, e.g. m x 1).
        step : float
            Step size of the new frequency grid.

        Returns
        -------
        data_new : np.array
            Interpolated data, shape (len(wn_new), len(time)).
        time_mesh, wn_mesh : np.array
            Coordinate meshes matching ``data_new``. Both axes are returned
            in increasing order; the new frequency grid runs from min(wn)
            (inclusive) to max(wn) (exclusive) in steps of ``step``.
        """
        # interp2d was deprecated in SciPy 1.10 and removed in 1.14;
        # RectBivariateSpline is the documented replacement for gridded data.
        from scipy.interpolate import RectBivariateSpline

        time_1d = np.ravel(time)
        wn_1d = np.ravel(wn)
        # RectBivariateSpline needs strictly increasing axes, so sort both
        # axes (and the matching rows/columns of data); this also makes
        # descending wavenumbers work.
        t_order = np.argsort(time_1d)
        w_order = np.argsort(wn_1d)
        time_sorted = time_1d[t_order]
        wn_sorted = wn_1d[w_order]
        data_sorted = np.asarray(data)[np.ix_(w_order, t_order)]

        spline = RectBivariateSpline(
            wn_sorted, time_sorted, data_sorted, kx=3, ky=3, s=0
        )
        wn_new = np.arange(wn_sorted[0], wn_sorted[-1], step)
        data_new = spline(wn_new, time_sorted)
        time_mesh, wn_mesh = np.meshgrid(time_sorted, wn_new)
        return data_new, time_mesh, wn_mesh

    def plot_heatmap(self, data, time, wn, title='data', newfig=True,
                     interpolate=False, step=.5):
        """
        Plots a nice looking heatmap.

        Parameters
        ----------
        data : np.array
            Data matrix subjected to SVD. Assuming *m x n* with m as
            frequency and n as time. But it is actually not important.
        time : np.array
            Time array.
        wn : np.array
            Frequency array.
        title : str
            Title of plot. Default *data*.
        newfig : boolean
            Setting to False prevents the creation of a new figure.
        interpolate : boolean
            True for interpolation
        step : float
            Step size for frequency interpolation.

        Returns
        -------
        matplotlib.collections.QuadMesh
        """
        if newfig:
            plt.figure()
        # ensuring that time spans columns
        if data.shape[1] != time.size:
            data = np.transpose(data)

        if time.size == 0 or wn.size == 0:
            pc = plt.pcolormesh(data, cmap='jet', shading='gouraud')
        else:
            if interpolate:
                data, time, wn = self.do_interpolate(data, time, wn, step)
            pc = plt.pcolormesh(
                time, wn, data, cmap='jet', shading='gouraud',
                norm=MidpointNormalize(midpoint=0),
            )
            plt.xscale('log')
        plt.title(title)
        return pc

    @staticmethod
    def plot_contourmap(data, time, wn, title='data', newfig=True):
        """
        Plots a nice looking contourmap.

        Parameters
        ----------
        data : np.array
            Data matrix subjected to SVD. Assuming *m x n* with m as
            frequency and n as time. But it is actually not important.
        time : np.array
            Time array.
        wn : np.array
            Frequency array.
        title : str
            Title of plot. Default *data*.
        newfig : boolean
            Setting to False prevents the creation of a new figure.

        Returns
        -------
        matplotlib.contour.QuadContourSet
            The contour-line set drawn on top of the filled contours.
        """
        if newfig:
            plt.figure()
        # ensuring that time spans columns
        if data.shape[1] != time.size:
            data = np.transpose(data)

        levels = 10
        if time.size == 0 or wn.size == 0:
            plt.contourf(
                data, levels=levels, cmap='bwr',
                norm=MidpointNormalize(midpoint=0)
            )
            pc = plt.contour(
                data, levels=levels, linewidths=0.3, colors='k'
            )
        else:
            plt.contourf(
                time.flatten(), wn.flatten(), data, levels=levels,
                cmap='bwr', norm=MidpointNormalize(midpoint=0)
            )
            pc = plt.contour(
                time.flatten(), wn.flatten(), data, levels=levels,
                linewidths=0.1, colors='k'
            )
            plt.xscale('log')
        plt.title(title)
        return pc

    def plot_surface(self, data, time, wn, title='data', interpolate=False,
                     step=.5):
        """
        Plots a nice looking 3D surface (time on a log10 axis).

        Parameters
        ----------
        data : np.array
            Data matrix subjected to SVD. Assuming *m x n* with m as
            frequency and n as time. But it is actually not important.
        time : np.array
            Time array.
        wn : np.array
            Frequency array.
        title : str
            Title of plot. Default *data*.
        interpolate : boolean
            True for interpolation
        step : float
            Step size for frequency interpolation.

        Returns
        -------
        The plotted surface (or mesh when no axes are given).
        """
        plt.figure()
        ax = plt.axes(projection="3d")
        # ensuring that time spans columns
        if data.shape[1] != time.size:
            data = np.transpose(data)

        if time.size == 0 or wn.size == 0:
            surf = plt.pcolormesh(data, cmap='jet', shading='gouraud')
            plt.xscale('log')
        else:
            if interpolate:
                data, x, y = self.do_interpolate(data, time, wn, step)
                x = np.log10(x)
            else:
                x, y = np.meshgrid(np.log10(time), wn)
            surf = ax.plot_surface(
                x, y, data, cmap='jet', lw=0.5, antialiased=True,
                shade=True, rstride=7, cstride=7
            )
        plt.title(title)
        return surf

    def plot_traces(self, res, alpha=-1, index_alpha=-1):
        """
        Plots interactive time traces.

        A slider selects the frequency whose trace is shown; the
        'Adjust y-axis' check box rescales the y-axis to the current raw
        trace on every slider move.

        Parameters
        ----------
        res : mysvd.Results
            Contains the data to be plotted.
        alpha : float
            Alpha value.
        index_alpha : int
            Index of alpha value.
        """
        title, procdata = self._select_procdata(
            res, alpha, index_alpha, 'Time traces'
        )

        fig = plt.figure()
        fig.suptitle(title)
        ax = fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.2)
        plt.plot([np.min(res.time), np.max(res.time)], [0, 0], '--', color='k')
        l1, = plt.plot(res.time.T, res.data[0, :], 'o-', markersize=2)
        l2 = None
        if procdata is not None:
            l2, = plt.plot(res.time.T, procdata[0, :], 'o-', markersize=2)
        plt.xscale('log')
        ax.margins(x=0)

        axfreq = plt.axes([0.175, 0.05, 0.65, 0.03], facecolor=self.axcolor)
        sfreq = Slider(
            axfreq, res.wn_name, np.min(res.wn), np.max(res.wn),
            valinit=np.min(res.wn),
            valstep=abs(res.wn[1, 0] - res.wn[0, 0])
        )

        ylim_full = self._padded_limits(res.data)
        ax.set_ylim(*ylim_full)
        ax.set_xlabel(res.time_name + ' / ' + res.time_unit)
        ax.set_xlim([np.min(res.time[0, :]), np.max(res.time[0, :])])

        ax_chx = plt.axes([0.012, 0.83, 0.2, 0.2], frameon=False)
        chxbox = CheckButtons(ax_chx, ['Adjust y-axis'], [False])

        # The callbacks use this figure's own widgets (closure variables),
        # not self attributes, which the next plot_traces call overwrites.
        def update(val):
            ind = abs(val - res.wn).argmin()
            sfreq.valtext.set_text('%.2f' % (res.wn[ind, 0]))
            l1.set_ydata(res.data[ind, :])
            if l2 is not None:
                l2.set_ydata(procdata[ind, :])
            if chxbox.get_status()[0]:
                ax.set_ylim(*self._padded_limits(l1.get_ydata()))

        def chxboxfunc(label):
            self.yadjust_traces = chxbox.get_status()[0]
            ax.set_ylim(*ylim_full)
            update(sfreq.val)
            # redraw this figure, not whichever figure pyplot considers current
            fig.canvas.draw_idle()

        sfreq.on_changed(update)
        chxbox.on_clicked(chxboxfunc)

        self.fig_traces = fig
        self.ax_traces = ax
        self.l1_traces = l1
        if l2 is not None:
            self.l2_traces = l2
        self.axfreq = axfreq
        self.sfreq = sfreq
        self.chxbox_traces = chxbox
        self.ymin_traces, self.ymax_traces = ylim_full
        self.yadjust_traces = False

    def plot_spectra(self, res, alpha=-1, index_alpha=-1, rev=True):
        """
        Plots interactive spectra.

        A slider on log10(time) selects the spectrum shown; the
        'Adjust y-axis' check box rescales the y-axis to the current raw
        spectrum on every slider move.

        Parameters
        ----------
        res : mysvd.Results
            Contains the data to be plotted.
        alpha : float
            Alpha value.
        index_alpha : int
            Index of alpha value.
        rev : boolean
            Reverses the x-axis.
        """
        title, procdata = self._select_procdata(
            res, alpha, index_alpha, 'Spectra'
        )

        fig = plt.figure()
        fig.suptitle(title)
        ax = fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.2)
        plt.plot([np.min(res.wn), np.max(res.wn)], [0, 0], '--', color='k')
        l1, = plt.plot(res.wn, res.data[:, 0], 'o-', markersize=3)
        l2 = None
        if procdata is not None:
            l2, = plt.plot(res.wn, procdata[:, 0], 'o-', markersize=3)
        ax.margins(x=0)

        axtime = plt.axes([0.175, 0.05, 0.65, 0.03], facecolor=self.axcolor)
        # The slider works on log10(time); its step is the smallest
        # difference between neighbouring log10 time values.
        log_time = np.log10(res.time[0, :])
        stime = Slider(
            axtime, res.time_name,
            np.log10(np.min(res.time)), np.log10(np.max(res.time)),
            valinit=np.log10(res.time[0, 0]),
            valstep=np.min(np.diff(log_time))
        )
        stime.valtext.set_text('%1.2e' % (10**stime.val))

        ylim_full = self._padded_limits(res.data)
        ax.set_ylim(*ylim_full)
        ax.set_xlabel(res.wn_name + ' / ' + res.wn_unit)
        wn_min = np.min(res.wn[:, 0])
        wn_max = np.max(res.wn[:, 0])
        if rev:
            ax.set_xlim([wn_max, wn_min])
        else:
            ax.set_xlim([wn_min, wn_max])

        ax_chx = plt.axes([0.012, 0.83, 0.2, 0.2], frameon=False)
        chxbox = CheckButtons(ax_chx, ['Adjust y-axis'], [False])

        # The callbacks use this figure's own widgets (closure variables),
        # not self attributes, which the next plot_spectra call overwrites.
        def update(val):
            ind = abs(10 ** val - res.time).argmin()
            stime.valtext.set_text('%1.2e' % (res.time[0, ind]))
            l1.set_ydata(res.data[:, ind])
            if l2 is not None:
                l2.set_ydata(procdata[:, ind])
            if chxbox.get_status()[0]:
                ax.set_ylim(*self._padded_limits(l1.get_ydata()))

        def chxboxfunc(label):
            self.yadjust_spectra = chxbox.get_status()[0]
            ax.set_ylim(*ylim_full)
            update(stime.val)
            # redraw this figure, not whichever figure pyplot considers current
            fig.canvas.draw_idle()

        stime.on_changed(update)
        chxbox.on_clicked(chxboxfunc)

        self.fig_spectra = fig
        self.ax_spectra = ax
        self.l1_spectra = l1
        if l2 is not None:
            self.l2_spectra = l2
        self.axtime = axtime
        self.stime = stime
        self.chxbox_spectra = chxbox
        self.ymin_spectra, self.ymax_spectra = ylim_full
        self.yadjust_spectra = False

    @staticmethod
    def append_ldamap(res, index_alpha=-1):
        """
        Appends NaN values in order to expand the taus array to match
        the time array span.

        Parameters
        ----------
        res : mylda.Results
            Contains the results to be plotted.
        index_alpha : int
            Index of selected alpha value (0 is valid). -1 uses ``res.x_k``
            as is (e.g. for TSVD results without an alpha axis).

        Returns
        -------
        x_k : np.array
            LDA map of selected alpha value, padded with NaN columns.
        taus : np.array
            Extended taus array, always shaped (1, n).

        Raises
        ------
        ValueError
            If ``index_alpha`` is negative but not -1.
        """
        if index_alpha >= 0:
            x_k = res.x_k[:, :, index_alpha]
        elif index_alpha == -1:
            x_k = res.x_k
        else:
            raise ValueError(
                'index_alpha must be >= 0 or -1, got %r' % (index_alpha,)
            )

        # Two NaN columns per padded side, matching the two tau values
        # inserted there (outer time bound and a point just beyond taus).
        # np.nan replaces np.NaN, which was removed in NumPy 2.0.
        nancols = np.full((x_k.shape[0], 2), np.nan)
        taus = res.taus
        if np.min(res.time) < np.min(res.taus):
            taus = np.insert(
                taus, 0, [np.min(res.time), np.min(res.taus)*0.99]
            )
            x_k = np.hstack((nancols, x_k))
        if np.max(res.time) > np.max(res.taus):
            taus = np.append(taus, [np.max(res.taus)*1.01, np.max(res.time)])
            x_k = np.hstack((x_k, nancols))
        # np.insert/np.append flatten; reshape in every case so the
        # returned shape does not depend on which branch ran.
        taus = np.reshape(taus, (1, np.size(taus)))
        return x_k, taus

    def plot_ldaresults(self, res):
        """
        Plots interactive results of LDA.

        Creates a 2 x 2 figure with the original data, the LDA data, the
        LDA map and (Tikhonov only) the L-curve. For Tikhonov results an
        alpha slider updates the LDA map, the LDA data and the highlighted
        L-curve point.

        Parameters
        ----------
        res : mylda.Results
            Contains the results to be plotted.

        Raises
        ------
        ValueError
            If ``res.method`` is neither 'tik' nor 'tsvd'.
        """
        wn_label = '%s / %s' % (res.wn_name, res.wn_unit)
        time_label = '%s / %s' % (res.time_name, res.time_unit)
        tau_label = '%s / %s' % ('tau', res.time_unit)

        if res.method == 'tik':
            # start roughly in the middle of the alpha range
            index_alpha = int(np.ceil(res.alphas.size/2))
            x_k_sel = res.x_k[:, :, index_alpha]
        elif res.method == 'tsvd':
            index_alpha = -1
            x_k_sel = res.x_k
        else:
            raise ValueError('Unknown LDA method: %r' % (res.method,))

        fig, axs = plt.subplots(2, 2, figsize=[6.4*2, 4.8*2])

        # original data
        plt.sca(axs[0, 0])
        self.plot_heatmap(
            res.data, res.time, res.wn, title='Original Data', newfig=False)
        plt.ylabel(wn_label)
        plt.xlabel(time_label)

        # lda data
        ldadata = np.transpose(res.dmatrix.dot(x_k_sel.T))
        plt.sca(axs[0, 1])
        pc_ldadata = self.plot_heatmap(
            ldadata, res.time, res.wn, title='LDA data', newfig=False)
        plt.ylabel(wn_label)
        plt.xlabel(time_label)

        # lda map
        plt.sca(axs[1, 1])
        x_k, taus = self.append_ldamap(res, index_alpha)
        pc_map = self.plot_contourmap(
            x_k, taus, res.wn, title='LDA Map', newfig=False)
        plt.ylabel(wn_label)
        plt.xlabel(tau_label)

        self.fig_lda = fig
        self.axs_lda = axs
        self.map_lda = pc_map
        self.ldadata = pc_ldadata

        # The L-curve and the alpha slider only exist for Tikhonov results;
        # previously the TSVD path crashed referencing them.
        if res.method != 'tik':
            return

        # lcurve
        plt.sca(axs[1, 0])
        plt.plot(res.lcurve[:, 0], res.lcurve[:, 1], 'o-', markersize=2)
        l1, = plt.plot(
            res.lcurve[index_alpha, 0], res.lcurve[index_alpha, 1],
            'o-', markersize=4, color='r'
        )
        plt.title('L-curve')

        plt.tight_layout()
        plt.subplots_adjust(bottom=0.125)
        axalpha = plt.axes(
            [0.175, 0.03, 0.65, 0.02], facecolor=self.axcolor
        )
        salpha = Slider(
            axalpha, 'Alpha',
            np.log10(np.min(res.alphas)), np.log10(np.max(res.alphas)),
            valinit=np.log10(res.alphas[index_alpha]),
            valstep=abs(np.log10(res.alphas[1]) - np.log10(res.alphas[0]))
        )
        salpha.valtext.set_text('%1.2e' % (10**salpha.val))

        def update(val):
            ind = abs(10 ** val - res.alphas).argmin()
            salpha.valtext.set_text('%1.2e' % (res.alphas[ind]))
            # lda map (ind may be 0, which append_ldamap now accepts)
            plt.sca(axs[1, 1])
            plt.cla()
            new_map, _ = self.append_ldamap(res, ind)
            self.plot_contourmap(
                new_map, taus, res.wn, title='LDA Map', newfig=False)
            plt.ylabel(wn_label)
            plt.xlabel(tau_label)
            # lda data, oriented exactly as plot_heatmap oriented it
            new_data = np.transpose(res.dmatrix.dot(res.x_k[:, :, ind].T))
            if new_data.shape[1] != res.time.size:
                new_data = np.transpose(new_data)
            pc_ldadata.set_array(new_data.ravel())
            # lcurve: set_data needs sequences (scalars deprecated in mpl 3.7)
            l1.set_data([res.lcurve[ind, 0]], [res.lcurve[ind, 1]])

        salpha.on_changed(update)

        self.l1_lda = l1
        self.axalpha = axalpha
        self.salpha = salpha

    def plot_solutionvector(self, res, alpha=-1, index_alpha=-1):
        """
        Plots interactive solution vector.

        Initially the sum of absolute values over all frequencies is shown.
        Unchecking 'Show sum' reveals a frequency slider selecting a single
        row of the solution vector; 'Adjust y-axis' rescales the y-axis to
        that row on every slider move.

        Parameters
        ----------
        res : mylda.Results
            Contains the data to be plotted.
        alpha : float
            Plot for the closest alpha as specified.
        index_alpha : int
            Plot for specified alpha at index.
        """
        x_k, title = res.get_xk(alpha, index_alpha)
        abs_sum = np.sum(np.abs(x_k), axis=0)

        fig = plt.figure()
        fig.suptitle('Solution vector ' + title[8:])
        ax = fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.2)
        plt.plot([np.min(res.taus), np.max(res.taus)], [0, 0], '--', color='k')
        l1, = plt.plot(res.taus.T, abs_sum, 'o-', markersize=3)
        plt.xscale('log')
        ax.margins(x=0)

        axsvec = plt.axes([0.175, 0.05, 0.65, 0.03], facecolor=self.axcolor)
        ssvec = Slider(
            axsvec, res.wn_name, np.min(res.wn), np.max(res.wn),
            valinit=np.min(res.wn),
            valstep=abs(res.wn[1, 0] - res.wn[0, 0])
        )

        ylim_sum = self._padded_limits(abs_sum)
        ylim_all = self._padded_limits(x_k)
        ax.set_ylim(*ylim_sum)
        ax.set_xlabel('time constant / ' + res.time_unit)
        ax.set_xlim([np.min(res.taus[0, :]), np.max(res.taus[0, :])])

        ax_chx = plt.axes([0.012, 0.83, 0.2, 0.2], frameon=False)
        chxbox = CheckButtons(
            ax_chx, ['Show sum', 'Adjust y-axis'], [True, False]
        )

        # The callbacks use this figure's own widgets (closure variables),
        # not self attributes, which the next call overwrites.
        def update(val):
            ind = abs(val - res.wn).argmin()
            ssvec.valtext.set_text('%.2f' % (res.wn[ind, 0]))
            l1.set_ydata(x_k[ind, :])
            if chxbox.get_status()[1]:
                ax.set_ylim(*self._padded_limits(x_k[ind, :]))

        def chxboxfunc(label):
            show_sum = chxbox.get_status()[0]
            if show_sum:
                axsvec.set_visible(False)
                l1.set_ydata(abs_sum)
                ax.set_ylim(*ylim_sum)
            else:
                axsvec.set_visible(True)
                # full range first, so update() can still narrow it when
                # 'Adjust y-axis' is checked
                ax.set_ylim(*ylim_all)
                update(ssvec.val)
            fig.canvas.draw_idle()

        ssvec.on_changed(update)
        # the slider is only meaningful once 'Show sum' is unchecked
        axsvec.set_visible(False)
        chxbox.on_clicked(chxboxfunc)

        self.fig_svec = fig
        self.ax_svec = ax
        self.axsvec = axsvec
        self.ssvec = ssvec
        self.chxbox_svec = chxbox
        self.ymin_svec, self.ymax_svec = ylim_sum
class MidpointNormalize(colors.Normalize):
    """
    Class for setting 0 as midpoint in the colormap.

    Values are mapped piecewise linearly: ``vmin`` -> 0, ``midpoint`` -> 0.5
    and ``vmax`` -> 1, so data on either side of the midpoint use their own
    half of the colormap. Values outside [vmin, vmax] are clamped to 0 or 1.

    Attributes
    ----------
    midpoint : float
        Midpoint of colormap. If None, plain linear normalization is used.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def _anchors(self):
        """Return the increasing data values that are mapped to 0, 0.5, 1."""
        # If the midpoint lies outside [vmin, vmax] (e.g. all-positive data
        # with midpoint 0), the points [vmin, midpoint, vmax] would not be
        # increasing and np.interp would return meaningless values.
        # Stretching the outer anchor to the midpoint keeps them ordered,
        # so such data uses only one half of the colormap.
        return (min(self.vmin, self.midpoint), self.midpoint,
                max(self.vmax, self.midpoint))

    def __call__(self, value, clip=None):
        if self.midpoint is None:
            return super().__call__(value, clip)
        # Fill in a missing vmin/vmax from the data, as Normalize does.
        self.autoscale_None(value)
        result = np.interp(np.ma.getdata(value), self._anchors(), [0, 0.5, 1])
        # Keep the input mask so masked entries stay masked.
        return np.ma.masked_array(result, mask=np.ma.getmask(value))

    def inverse(self, value):
        """Map normalized values in [0, 1] back to data values."""
        if self.midpoint is None:
            return super().inverse(value)
        if not self.scaled():
            raise ValueError('Not invertible until both vmin and vmax are set')
        return np.interp(value, [0, 0.5, 1], self._anchors())