Source code for trtoolbox.plothelper

# TODO: plot raw data stuff

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from matplotlib.widgets import CheckButtons
import matplotlib.colors as colors
from mpl_toolkits import mplot3d


class PlotHelper:
    """
    Object for interactive plotting.

    This class keeps references to all the matplotlib objects of the
    interactive plots. Without these references the widgets (sliders and
    check boxes) could be garbage collected and stop responding.

    Attributes
    ----------
    fig_traces : matplotlib.figure.Figure
    ax_traces : matplotlib.axes.Axes
    l1_traces : matplotlib.lines.Line2D
    l2_traces : matplotlib.lines.Line2D
    axfreq : matplotlib.axes.Axes
    sfreq : matplotlib.widgets.Slider
    chxbox_traces : matplotlib.widgets.CheckButtons or None
    ymin_traces, ymax_traces : float or None
        Default y-limits of the latest time-trace figure.
    yadjust_traces : bool
        Last state of the 'Adjust y-axis' check box of the traces plot.
    fig_spectra : matplotlib.figure.Figure
    ax_spectra : matplotlib.axes.Axes
    l1_spectra : matplotlib.lines.Line2D
    l2_spectra : matplotlib.lines.Line2D
    axtime : matplotlib.axes.Axes
    stime : matplotlib.widgets.Slider
    chxbox_spectra : matplotlib.widgets.CheckButtons or None
    ymin_spectra, ymax_spectra : float or None
        Default y-limits of the latest spectra figure.
    yadjust_spectra : bool
        Last state of the 'Adjust y-axis' check box of the spectra plot.
    fig_lda : matplotlib.figure.Figure
    axs_lda : np.array
    map_lda : matplotlib.contour.QuadContourSet or None
    ldadata : matplotlib.collections.QuadMesh or None
    l1_lda : matplotlib.lines.Line2D
    axalpha : matplotlib.axes.Axes
    salpha : matplotlib.widgets.Slider
    axcolor : str
        Face colour of the slider axes.
    """

    def __init__(self):
        self.fig_traces = matplotlib.figure.Figure()
        self.ax_traces = matplotlib.axes.Axes(self.fig_traces, [0, 0, 0, 0])
        self.l1_traces = matplotlib.lines.Line2D([], [])
        self.l2_traces = matplotlib.lines.Line2D([], [])
        self.axfreq = matplotlib.axes.Axes(self.fig_traces, [0, 0, 0, 0])
        self.sfreq = matplotlib.widgets.Slider(self.axfreq, '', 0, 1)
        self.chxbox_traces = None
        self.ymin_traces = None
        self.ymax_traces = None
        self.yadjust_traces = False

        self.fig_spectra = matplotlib.figure.Figure()
        self.ax_spectra = matplotlib.axes.Axes(
            self.fig_spectra, [0, 0, 0, 0])
        self.l1_spectra = matplotlib.lines.Line2D([], [])
        self.l2_spectra = matplotlib.lines.Line2D([], [])
        self.axtime = matplotlib.axes.Axes(self.fig_spectra, [0, 0, 0, 0])
        self.stime = matplotlib.widgets.Slider(self.axtime, '', 0, 1)
        self.chxbox_spectra = None
        self.ymin_spectra = None
        self.ymax_spectra = None
        self.yadjust_spectra = False

        self.fig_lda = matplotlib.figure.Figure()
        # placeholder for the 2 x 2 grid of axes created by plot_ldaresults
        self.axs_lda = np.empty((2, 2), dtype=object)
        self.map_lda = None
        self.ldadata = None
        self.l1_lda = matplotlib.lines.Line2D([], [])
        self.axalpha = matplotlib.axes.Axes(self.fig_lda, [0, 0, 0, 0])
        self.salpha = matplotlib.widgets.Slider(self.axalpha, '', 0, 1)

        self.axcolor = 'lightgoldenrodyellow'

    @staticmethod
    def _padded_limits(lo, hi, sc=1.05):
        """
        Widen the interval [lo, hi] outwards for use as axis limits.

        Each bound is moved away from the data by ``(sc - 1)`` times its
        own magnitude. For ``lo <= 0 <= hi`` this equals
        ``(lo*sc, hi*sc)``. Unlike plain scaling it never moves a bound
        inwards, so data that is entirely positive or entirely negative
        is not clipped.

        Returns
        -------
        (float, float)
            Lower and upper limit.
        """
        pad = sc - 1
        return lo - pad * abs(lo), hi + pad * abs(hi)

    @staticmethod
    def _title_and_procdata(res, prefix, alpha=-1, index_alpha=-1):
        """
        Select the processed data to overlay on the raw data of ``res``.

        Parameters
        ----------
        res : Results object
            Needs ``type`` ('raw', 'svd', 'lda' or 'gf') and the data
            attribute belonging to that type.
        prefix : str
            First line of the figure title, e.g. 'Time traces'.
        alpha : float
            Passed to ``res.get_alpha`` for Tikhonov LDA results.
        index_alpha : int
            Passed to ``res.get_alpha`` for Tikhonov LDA results.

        Returns
        -------
        title : str
            Figure title.
        procdata : np.array or None
            Processed data; None for raw data (nothing to compare to).

        Raises
        ------
        ValueError
            If ``res.type`` (or ``res.method`` for LDA) is not supported.
        """
        if res.type == 'raw':
            return 'Raw data', None
        if res.type == 'svd':
            return prefix + '\nblue: Raw, red: SVD', res.svddata
        if res.type == 'gf':
            return prefix + '\nblue: Raw, red: Global Fit', res.fitdata
        if res.type == 'lda':
            if res.method == 'tik':
                alpha, index_alpha = res.get_alpha(alpha, index_alpha)
                title = prefix + '\nblue: Raw, red: LDA (alpha=%.2f)' % alpha
                return title, res.fitdata[:, :, index_alpha]
            if res.method == 'tsvd':
                title = prefix + \
                    '\nblue: Raw, red: LDA (TSVD truncated at %i)' % res.k
                return title, res.fitdata
            raise ValueError('Unknown LDA method: %r' % (res.method,))
        raise ValueError('Unknown result type: %r' % (res.type,))

    @staticmethod
    def plot_heatmap(data, time, wn, title='data', newfig=True):
        """
        Plots a nice looking heatmap.

        Parameters
        ----------
        data : np.array
            Data matrix. Assuming *m x n* with m as frequency and n as
            time; it is transposed if its columns do not match ``time``.
        time : np.array
            Time array. If empty, the data is plotted against indices.
        wn : np.array
            Frequency array. If empty, the data is plotted against indices.
        title : str
            Title of plot. Default *data*.
        newfig : boolean
            Setting to False prevents the creation of a new figure.

        Returns
        -------
        matplotlib.collections.QuadMesh
            The heatmap returned by ``plt.pcolormesh``.
        """
        if newfig:
            plt.figure()
        # ensuring that time spans columns
        if data.shape[1] != time.size:
            data = np.transpose(data)
        if time.size == 0 or wn.size == 0:
            pc = plt.pcolormesh(data, cmap='jet', shading='gouraud')
        else:
            pc = plt.pcolormesh(
                time, wn, data,
                cmap='jet',
                shading='gouraud',
                norm=MidpointNormalize(midpoint=0),
            )
        plt.xscale('log')
        plt.title(title)
        return pc

    @staticmethod
    def plot_contourmap(data, time, wn, title='data', newfig=True):
        """
        Plots a nice looking contourmap.

        Filled contours (colour centred on 0) are overlaid with thin black
        contour lines.

        Parameters
        ----------
        data : np.array
            Data matrix. Assuming *m x n* with m as frequency and n as
            time; it is transposed if its columns do not match ``time``.
        time : np.array
            Time array. If empty, the data is plotted against indices.
        wn : np.array
            Frequency array. If empty, the data is plotted against indices.
        title : str
            Title of plot. Default *data*.
        newfig : boolean
            Setting to False prevents the creation of a new figure.

        Returns
        -------
        matplotlib.contour.QuadContourSet
            The contour *lines* (the last ``plt.contour`` call).
        """
        if newfig:
            plt.figure()
        # ensuring that time spans columns
        if data.shape[1] != time.size:
            data = np.transpose(data)
        levels = 10
        if time.size == 0 or wn.size == 0:
            plt.contourf(
                data, levels=levels, cmap='bwr',
                norm=MidpointNormalize(midpoint=0)
            )
            pc = plt.contour(
                data, levels=levels, linewidths=0.3, colors='k'
            )
        else:
            plt.contourf(
                time.flatten(), wn.flatten(), data, levels=levels,
                cmap='bwr', norm=MidpointNormalize(midpoint=0)
            )
            pc = plt.contour(
                time.flatten(), wn.flatten(), data, levels=levels,
                linewidths=0.1, colors='k'
            )
        plt.xscale('log')
        plt.title(title)
        return pc

    # TODO: better 3D plot
    @staticmethod
    def plot_surface(data, time, wn, title='data'):
        """
        Plots the data as a 3D surface over log10(time) and frequency.

        Parameters
        ----------
        data : np.array
            Data matrix. Assuming *m x n* with m as frequency and n as
            time; it is transposed if its columns do not match ``time``.
        time : np.array
            Time array. If empty, a pcolormesh against indices is drawn
            instead of a surface.
        wn : np.array
            Frequency array.
        title : str
            Title of plot. Default *data*.

        Returns
        -------
        matplotlib artist
            The surface (or pcolormesh in the fallback case).
        """
        plt.figure()
        ax = plt.axes(projection="3d")
        # ensuring that time spans columns
        if data.shape[1] != time.size:
            data = np.transpose(data)
        if time.size == 0 or wn.size == 0:
            surf = plt.pcolormesh(data, cmap='jet', shading='gouraud')
            plt.xscale('log')
        else:
            x, y = np.meshgrid(np.log10(time), wn)
            surf = ax.plot_surface(
                x, y, data,
                cmap='jet', lw=0.5, antialiased=True, shade=True,
                rstride=7, cstride=7
            )
        plt.title(title)
        return surf

    def plot_traces(self, res, alpha=-1, index_alpha=-1):
        """
        Plots interactive time traces.

        A slider selects the frequency whose time trace is shown. The
        'Adjust y-axis' check box rescales the y-axis to the shown trace.

        Parameters
        ----------
        res : Results object (e.g. mysvd.Results)
            Contains the data to be plotted.
        alpha : float
            Alpha value (Tikhonov LDA results only).
        index_alpha : int
            Index of alpha value (Tikhonov LDA results only).
        """
        title, procdata = self._title_and_procdata(
            res, 'Time traces', alpha, index_alpha)

        fig = plt.figure()
        fig.suptitle(title)
        ax = fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.2)
        plt.plot(
            [np.min(res.time), np.max(res.time)], [0, 0], '--', color='k')
        l1, = plt.plot(res.time.T, res.data[0, :], 'o-', markersize=2)
        l2 = None
        if procdata is not None:
            l2, = plt.plot(res.time.T, procdata[0, :], 'o-', markersize=2)
        plt.xscale('log')
        ax.margins(x=0)

        axfreq = plt.axes([0.175, 0.05, 0.65, 0.03], facecolor=self.axcolor)
        sfreq = Slider(
            axfreq, res.wn_name, np.min(res.wn), np.max(res.wn),
            valinit=np.min(res.wn),
            valstep=abs(res.wn[1, 0] - res.wn[0, 0])
        )

        ymin, ymax = self._padded_limits(
            np.nanmin(res.data), np.nanmax(res.data))
        ax.set_ylim(ymin, ymax)
        ax.set_xlabel(res.time_name + ' / ' + res.time_unit)
        ax.set_xlim([np.min(res.time[0, :]), np.max(res.time[0, :])])

        ax_chx = plt.axes([0.012, 0.83, 0.2, 0.2], frameon=False)
        chxbox = CheckButtons(ax_chx, ['Adjust y-axis'], [False])

        # The callbacks use the local widgets and limits (not self.*) so
        # that an older figure keeps working after plot_traces is called
        # again.
        def update(val):
            ind = abs(sfreq.val - res.wn).argmin()
            sfreq.valtext.set_text('%.2f' % (res.wn[ind, 0]))
            l1.set_ydata(res.data[ind, :])
            if l2 is not None:
                l2.set_ydata(procdata[ind, :])
            if chxbox.get_status()[0]:
                trace = res.data[ind, :]
                ax.set_ylim(self._padded_limits(
                    np.nanmin(trace), np.nanmax(trace)))
            fig.canvas.draw_idle()

        def chxboxfunc(label):
            self.yadjust_traces = chxbox.get_status()[0]
            ax.set_ylim(ymin, ymax)
            update(sfreq.val)

        sfreq.on_changed(update)
        chxbox.on_clicked(chxboxfunc)

        self.fig_traces = fig
        self.ax_traces = ax
        self.l1_traces = l1
        if l2 is not None:
            self.l2_traces = l2
        self.axfreq = axfreq
        self.sfreq = sfreq
        self.chxbox_traces = chxbox
        self.ymin_traces = ymin
        self.ymax_traces = ymax
        self.yadjust_traces = False

    def plot_spectra(self, res, alpha=-1, index_alpha=-1, rev=True):
        """
        Plots interactive spectra.

        A slider (on a log10 time scale) selects the time whose spectrum
        is shown. The 'Adjust y-axis' check box rescales the y-axis to the
        shown spectrum.

        Parameters
        ----------
        res : Results object (e.g. mysvd.Results)
            Contains the data to be plotted.
        alpha : float
            Alpha value (Tikhonov LDA results only).
        index_alpha : int
            Index of alpha value (Tikhonov LDA results only).
        rev : boolean
            Reverses the x-axis.
        """
        title, procdata = self._title_and_procdata(
            res, 'Spectra', alpha, index_alpha)

        fig = plt.figure()
        fig.suptitle(title)
        ax = fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.2)
        plt.plot([np.min(res.wn), np.max(res.wn)], [0, 0], '--', color='k')
        l1, = plt.plot(res.wn, res.data[:, 0], 'o-', markersize=3)
        l2 = None
        if procdata is not None:
            l2, = plt.plot(res.wn, procdata[:, 0], 'o-', markersize=3)
        ax.margins(x=0)

        axtime = plt.axes([0.175, 0.05, 0.65, 0.03], facecolor=self.axcolor)
        log_time = np.log10(res.time[0, :])
        # smallest spacing between neighbouring time points in log10 space
        stime = Slider(
            axtime, res.time_name,
            np.log10(np.min(res.time)), np.log10(np.max(res.time)),
            valinit=log_time[0],
            valstep=np.min(np.diff(log_time))
        )
        stime.valtext.set_text('%1.2e' % (10**stime.val))

        ymin, ymax = self._padded_limits(
            np.nanmin(res.data), np.nanmax(res.data))
        ax.set_ylim(ymin, ymax)
        ax.set_xlabel(res.wn_name + ' / ' + res.wn_unit)
        wn_min = np.min(res.wn[:, 0])
        wn_max = np.max(res.wn[:, 0])
        ax.set_xlim([wn_max, wn_min] if rev else [wn_min, wn_max])

        ax_chx = plt.axes([0.012, 0.83, 0.2, 0.2], frameon=False)
        chxbox = CheckButtons(ax_chx, ['Adjust y-axis'], [False])

        # The callbacks use the local widgets and limits (not self.*) so
        # that an older figure keeps working after plot_spectra is called
        # again.
        def update(val):
            ind = abs(10 ** stime.val - res.time).argmin()
            stime.valtext.set_text('%1.2e' % (res.time[0, ind]))
            l1.set_ydata(res.data[:, ind])
            if l2 is not None:
                l2.set_ydata(procdata[:, ind])
            if chxbox.get_status()[0]:
                spectrum = res.data[:, ind]
                ax.set_ylim(self._padded_limits(
                    np.nanmin(spectrum), np.nanmax(spectrum)))
            fig.canvas.draw_idle()

        def chxboxfunc(label):
            self.yadjust_spectra = chxbox.get_status()[0]
            ax.set_ylim(ymin, ymax)
            update(stime.val)

        stime.on_changed(update)
        chxbox.on_clicked(chxboxfunc)

        self.fig_spectra = fig
        self.ax_spectra = ax
        self.l1_spectra = l1
        if l2 is not None:
            self.l2_spectra = l2
        self.axtime = axtime
        self.stime = stime
        self.chxbox_spectra = chxbox
        self.ymin_spectra = ymin
        self.ymax_spectra = ymax
        self.yadjust_spectra = False

    @staticmethod
    def append_ldamap(res, index_alpha=-1):
        """
        Appends NaN values in order to expand the taus array to match the
        time array span.

        Parameters
        ----------
        res : mylda.Results
            Contains the results to be plotted.
        index_alpha : int
            Index of selected alpha value (third axis of ``res.x_k``).
            Any negative value uses ``res.x_k`` as it is, as needed for
            results without an alpha axis.

        Returns
        -------
        x_k : np.array
            LDA map of selected alpha value, padded with NaN columns.
        taus : np.array
            Extended taus array, always shaped (1, n).
        """
        if index_alpha >= 0:
            x_k = res.x_k[:, :, index_alpha]
        else:
            x_k = res.x_k

        # two NaN columns per extended side: one at the time limit and one
        # just outside the taus range, so the map spans the time axis
        nanarray = np.full(np.shape(res.wn), np.nan)
        nanarray = np.hstack((nanarray, nanarray))

        taus = res.taus
        if np.min(res.time) < np.min(res.taus):
            taus = np.insert(
                taus, 0, [np.min(res.time), np.min(res.taus)*0.99]
            )
            x_k = np.hstack((nanarray, x_k))
        if np.max(res.time) > np.max(res.taus):
            taus = np.append(taus, [np.max(res.taus)*1.01, np.max(res.time)])
            x_k = np.hstack((x_k, nanarray))
        # np.insert/np.append flatten, so restore the row-vector shape
        taus = np.reshape(taus, (1, -1))
        return x_k, taus

    def plot_ldaresults(self, res):
        """
        Plots interactive results of LDA.

        Draws a 2 x 2 grid: original data (top left), data reconstructed
        from the LDA map (top right) and the LDA map (bottom right). For
        Tikhonov results the L-curve is drawn bottom left, and a slider
        selects the regularization parameter alpha.

        Parameters
        ----------
        res : mylda.Results
            Contains the results to be plotted. ``res.method`` must be
            'tik' or 'tsvd'.

        Raises
        ------
        ValueError
            If ``res.method`` is neither 'tik' nor 'tsvd'.
        """
        if res.method == 'tik':
            # start around the middle of the alpha scan; clamped so that a
            # single alpha value does not index out of range
            index_alpha = min(
                int(np.ceil(res.alphas.size/2)), res.alphas.size - 1)
            x_k_sel = res.x_k[:, :, index_alpha]
        elif res.method == 'tsvd':
            index_alpha = -1
            x_k_sel = res.x_k
        else:
            raise ValueError('Unknown LDA method: %r' % (res.method,))

        fig, axs = plt.subplots(2, 2, figsize=[6.4*2, 4.8*2])
        wn_label = '%s / %s' % (res.wn_name, res.wn_unit)
        time_label = '%s / %s' % (res.time_name, res.time_unit)
        tau_label = '%s / %s' % ('tau', res.time_unit)

        # original data
        plt.sca(axs[0, 0])
        self.plot_heatmap(
            res.data, res.time, res.wn, title='Original Data', newfig=False)
        plt.ylabel(wn_label)
        plt.xlabel(time_label)

        # data reconstructed from the LDA map
        ldadata = np.transpose(res.dmatrix.dot(x_k_sel.T))
        plt.sca(axs[0, 1])
        pc_ldadata = self.plot_heatmap(
            ldadata, res.time, res.wn, title='LDA data', newfig=False)
        plt.ylabel(wn_label)
        plt.xlabel(time_label)

        # lda map
        plt.sca(axs[1, 1])
        x_k, taus = self.append_ldamap(res, index_alpha)
        pc_map = self.plot_contourmap(
            x_k, taus, res.wn, title='LDA Map', newfig=False)
        plt.ylabel(wn_label)
        plt.xlabel(tau_label)

        if res.method == 'tik':
            # lcurve with the selected alpha highlighted in red
            plt.sca(axs[1, 0])
            plt.plot(res.lcurve[:, 0], res.lcurve[:, 1], 'o-', markersize=2)
            l1, = plt.plot(
                res.lcurve[index_alpha, 0], res.lcurve[index_alpha, 1],
                'o-', markersize=4, color='r'
            )
            plt.title('L-curve')
            plt.tight_layout()
            plt.subplots_adjust(bottom=0.125)

            axalpha = plt.axes(
                [0.175, 0.03, 0.65, 0.02], facecolor=self.axcolor
            )
            if res.alphas.size > 1:
                valstep = abs(np.log10(res.alphas[1]) - np.log10(res.alphas[0]))
            else:
                valstep = None
            salpha = Slider(
                axalpha, 'Alpha',
                np.log10(np.min(res.alphas)), np.log10(np.max(res.alphas)),
                valinit=np.log10(res.alphas[index_alpha]),
                valstep=valstep
            )
            salpha.valtext.set_text('%1.2e' % (10**salpha.val))

            def update(val):
                ind = abs(10 ** salpha.val - res.alphas).argmin()
                salpha.valtext.set_text('%1.2e' % (res.alphas[ind]))
                # lda map (ind may be 0, the smallest alpha)
                plt.sca(axs[1, 1])
                plt.cla()
                x_k_new, _ = self.append_ldamap(res, ind)
                self.plot_contourmap(
                    x_k_new, taus, res.wn, title='LDA Map', newfig=False)
                plt.ylabel(wn_label)
                plt.xlabel(tau_label)
                # lda data
                ldadata_new = np.transpose(
                    res.dmatrix.dot(res.x_k[:, :, ind].T))
                pc_ldadata.set_array(ldadata_new.ravel())
                # lcurve; Line2D.set_data expects sequences, not scalars
                l1.set_data([res.lcurve[ind, 0]], [res.lcurve[ind, 1]])
                fig.canvas.draw_idle()

            salpha.on_changed(update)
            self.l1_lda = l1
            self.axalpha = axalpha
            self.salpha = salpha

        self.fig_lda = fig
        self.axs_lda = axs
        self.map_lda = pc_map
        self.ldadata = pc_ldadata

    def plot_solutionvector(self, res, alpha=-1, index_alpha=-1):
        """
        Plots interactive solution vector.

        By default the sum of the absolute LDA amplitudes over all
        frequencies is plotted against the time constants. Unchecking
        'Show sum' shows a single frequency selected with a slider.
        'Adjust y-axis' rescales the y-axis to that frequency.

        Parameters
        ----------
        res : mylda.Results
            Contains the data to be plotted.
        alpha : float
            Plot for the closest alpha as specified.
        index_alpha : int
            Plot for specified alpha at index.
        """
        x_k, title = res.get_xk(alpha, index_alpha)
        sum_abs = np.sum(np.abs(x_k), axis=0)

        fig = plt.figure()
        fig.suptitle('Solution vector ' + title[8:])
        ax = fig.add_subplot(111)
        plt.subplots_adjust(bottom=0.2)
        plt.plot(
            [np.min(res.taus), np.max(res.taus)], [0, 0], '--', color='k')
        l1, = plt.plot(res.taus.T, sum_abs, 'o-', markersize=3)
        plt.xscale('log')
        ax.margins(x=0)

        axsvec = plt.axes([0.175, 0.05, 0.65, 0.03], facecolor=self.axcolor)
        ssvec = Slider(
            axsvec, 'Frequency', np.min(res.wn), np.max(res.wn),
            valinit=np.min(res.wn),
            valstep=abs(res.wn[1, 0] - res.wn[0, 0])
        )

        ymin, ymax = self._padded_limits(
            np.nanmin(sum_abs), np.nanmax(sum_abs))
        ax.set_ylim(ymin, ymax)
        ax.set_xlabel('time constant / s')
        ax.set_xlim([np.min(res.taus[0, :]), np.max(res.taus[0, :])])

        ax_chx = plt.axes([0.012, 0.83, 0.2, 0.2], frameon=False)
        chxbox = CheckButtons(
            ax_chx, ['Show sum', 'Adjust y-axis'], [True, False]
        )

        # The callbacks use the local widgets and limits (not self.*) so
        # that an older figure keeps working after this method is called
        # again.
        def update(val):
            ind = abs(ssvec.val - res.wn).argmin()
            ssvec.valtext.set_text('%.2f' % (res.wn[ind, 0]))
            l1.set_ydata(x_k[ind, :])
            if chxbox.get_status()[1]:
                ax.set_ylim(self._padded_limits(
                    np.nanmin(x_k[ind, :]), np.nanmax(x_k[ind, :])))
            fig.canvas.draw_idle()

        def chxboxfunc(label):
            if chxbox.get_status()[0]:
                axsvec.set_visible(False)
                l1.set_ydata(sum_abs)
                ax.set_ylim(ymin, ymax)
            else:
                axsvec.set_visible(True)
                # full range first, so the 'Adjust y-axis' option applied
                # inside update() takes precedence
                ax.set_ylim([np.nanmin(x_k), np.nanmax(x_k)])
                update(ssvec.val)
            fig.canvas.draw_idle()

        ssvec.on_changed(update)
        # the slider is only needed when a single frequency is shown
        axsvec.set_visible(False)
        chxbox.on_clicked(chxboxfunc)

        self.fig_svec = fig
        self.ax_svec = ax
        self.axsvec = axsvec
        self.ssvec = ssvec
        self.chxbox_svec = chxbox
        self.ymin_svec = ymin
        self.ymax_svec = ymax
class MidpointNormalize(colors.Normalize):
    """
    Class for setting a midpoint (e.g. 0) in the centre of the colormap.

    Values in ``[vmin, midpoint]`` are mapped linearly onto ``[0, 0.5]``
    and values in ``[midpoint, vmax]`` onto ``[0.5, 1]``. With a diverging
    colormap, this draws the midpoint in the colormap's centre colour.

    Attributes
    ----------
    midpoint : float or None
        Data value mapped to 0.5. If None, the centre of ``[vmin, vmax]``
        is used, which is equivalent to a plain linear ``Normalize``.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        colors.Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """
        Map ``value`` onto ``[0, 1]``.

        ``clip`` is accepted for API compatibility. Results always lie in
        ``[0, 1]``, because ``np.interp`` maps values outside the sample
        points to the end values. The mask of a masked input is kept.
        """
        # like Normalize.__call__: take missing limits from the data
        self.autoscale_None(value)
        lo, hi = self.vmin, self.vmax
        mid = 0.5 * (lo + hi) if self.midpoint is None else self.midpoint
        # np.interp needs increasing sample points. If the midpoint is not
        # strictly inside [vmin, vmax], mirror the far limit about the
        # midpoint so that the midpoint still maps to 0.5.
        if mid <= lo:
            lo = 2 * mid - hi
        elif mid >= hi:
            hi = 2 * mid - lo
        mask = np.ma.getmask(value)
        if lo == hi:
            # vmin == vmax == midpoint: no spread, use the centre colour
            return np.ma.masked_array(
                np.full(np.shape(value), 0.5), mask=mask)
        return np.ma.masked_array(
            np.interp(value, [lo, mid, hi], [0, 0.5, 1]), mask=mask)