Commit 9630ca8b authored by René Fritze's avatar René Fritze Committed by René Fritze

[vis] refactor jupyter display with MPL

parent 79ef3c3f
......@@ -4,7 +4,8 @@
import itertools
import numpy as np
from ipywidgets import HTML, HBox
from IPython.core.display import display
from ipywidgets import HTML, HBox, widgets, Layout
import matplotlib.pyplot as plt
from pymor.core.config import config
......@@ -26,6 +27,67 @@ class concat_display(object):
return '\n\n'.join(repr(a) for a in self.args)
class MPLPlotBase:
def __init__(self, U, grid, codim, legend, bounding_box=None, separate_colorbars=False, count=None,
separate_plots=False):
assert isinstance(U, VectorArray) \
or (isinstance(U, tuple)
and all(isinstance(u, VectorArray) for u in U)
and all(len(u) == len(U[0]) for u in U))
self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else tuple(u.uid for u in U)
U = (U.to_numpy().astype(np.float64, copy=False),) if isinstance(U, VectorArray) else \
tuple(u.to_numpy().astype(np.float64, copy=False) for u in U)
if not config.HAVE_MATPLOTLIB:
raise ImportError('cannot visualize: import of matplotlib failed')
if not config.HAVE_IPYWIDGETS and len(U[0]) > 1:
raise ImportError('cannot visualize: import of ipywidgets failed')
self.legend = (legend,) if isinstance(legend, str) else legend
assert legend is None or isinstance(legend, tuple) and len(legend) == len(U)
self._set_limits(U)
self.plots = []
# this _supposed_ to let animations run in sync
sync_timer = None
do_animation = len(U[0]) > 1
for i, (vmin, vmax, u) in enumerate(zip(self.vmins, self.vmaxs, U)):
figure = plt.figure(self.fig_ids[i])
ax = plt.axes()
sync_timer = sync_timer or figure.canvas.new_timer()
if grid.dim == 2:
self.plots.append(MatplotlibPatchAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
bounding_box=bounding_box, vmin=vmin, vmax=vmax,
codim=codim, colorbar=separate_colorbars or i == len(U) - 1))
else:
assert count
self.plots.append(Matplotlib1DAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
vmin=vmin, vmax=vmax, count=count, codim=codim,
separate_plots=separate_plots))
if self.legend:
ax.set_title(self.legend[i])
plt.tight_layout()
if do_animation:
for fig_id in self.fig_ids:
# avoids figure double display
plt.close(fig_id)
self._cd = concat_display(*[p.html for p in self.plots])
# IPython display system checks for presence and calls this func
self._repr_html_ = self._cd._repr_html_
else:
self._out = widgets.Output()
with self._out:
plt.show()
# IPython display system checks for presence and calls this func
self._ipython_display_ = self._out._ipython_display_
def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None, legend=None,
separate_colorbars=False, rescale_colorbars=False, columns=2):
"""Visualize scalar data associated to a two-dimensional |Grid| as a patch plot.
......@@ -60,82 +122,43 @@ def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None,
at the same time.
"""
assert isinstance(U, VectorArray) \
or (isinstance(U, tuple)
and all(isinstance(u, VectorArray) for u in U)
and all(len(u) == len(U[0]) for u in U))
U = (U.to_numpy().astype(np.float64, copy=False),) if isinstance(U, VectorArray) else \
tuple(u.to_numpy().astype(np.float64, copy=False) for u in U)
if not config.HAVE_MATPLOTLIB:
raise ImportError('cannot visualize: import of matplotlib failed')
if not config.HAVE_IPYWIDGETS and len(U[0]) > 1:
raise ImportError('cannot visualize: import of ipywidgets failed')
if isinstance(legend, str):
legend = (legend,)
assert legend is None or isinstance(legend, tuple) and len(legend) == len(U)
class Plot:
class Plot(MPLPlotBase):
def __init__(self):
def _set_limits(self, np_U):
if separate_colorbars:
# todo rescaling not set up
if rescale_colorbars:
self.vmins = tuple(np.min(u[0]) for u in U)
self.vmaxs = tuple(np.max(u[0]) for u in U)
self.vmins = tuple(np.min(u[0]) for u in np_U)
self.vmaxs = tuple(np.max(u[0]) for u in np_U)
else:
self.vmins = tuple(np.min(u) for u in U)
self.vmaxs = tuple(np.max(u) for u in U)
self.vmins = tuple(np.min(u) for u in np_U)
self.vmaxs = tuple(np.max(u) for u in np_U)
else:
if rescale_colorbars:
self.vmins = (min(np.min(u[0]) for u in U),) * len(U)
self.vmaxs = (max(np.max(u[0]) for u in U),) * len(U)
self.vmins = (min(np.min(u[0]) for u in np_U),) * len(np_U)
self.vmaxs = (max(np.max(u[0]) for u in np_U),) * len(np_U)
else:
self.vmins = (min(np.min(u) for u in U),) * len(U)
self.vmaxs = (max(np.max(u) for u in U),) * len(U)
self.plots = plots = []
# this _supposed_ to let animations run in sync
sync_timer = None
for i, (vmin, vmax, u) in enumerate(zip(self.vmins, self.vmaxs, U)):
figure = plt.figure(i)
ax = plt.axes()
sync_timer = sync_timer or figure.canvas.new_timer()
plots.append(MatplotlibPatchAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
bounding_box=bounding_box, vmin=vmin, vmax=vmax,
codim=codim, colorbar=separate_colorbars or i == len(U)-1))
if legend:
ax.set_title(legend[i])
plt.tight_layout()
plt.close(figure)
def set(self, U, ind):
if rescale_colorbars:
if separate_colorbars:
self.vmins = tuple(np.min(u[ind]) for u in U)
self.vmaxs = tuple(np.max(u[ind]) for u in U)
self.vmins = (min(np.min(u) for u in np_U),) * len(np_U)
self.vmaxs = (max(np.max(u) for u in np_U),) * len(np_U)
def __init__(self):
super(Plot, self).__init__(U, grid, codim, legend, bounding_box=bounding_box,
separate_colorbars=separate_colorbars)
def set(self, ind):
np_U = self.U
if self.rescale_colorbars:
if self.separate_colorbars:
self.vmins = tuple(np.min(u[ind]) for u in np_U)
self.vmaxs = tuple(np.max(u[ind]) for u in np_U)
else:
self.vmins = (min(np.min(u[ind]) for u in U),) * len(U)
self.vmaxs = (max(np.max(u[ind]) for u in U),) * len(U)
self.vmins = (min(np.min(u[ind]) for u in np_U),) * len(np_U)
self.vmaxs = (max(np.max(u[ind]) for u in np_U),) * len(np_U)
for u, plot, vmin, vmax in zip(U, self.plots, self.vmins, self.vmaxs):
for u, plot, vmin, vmax in zip(np_U, self.plots, self.vmins, self.vmaxs):
plot.set(u[ind], vmin=vmin, vmax=vmax)
plot = Plot()
if len(U[0]) > 1:
return concat_display(*[p.html for p in plot.plots])
return plot
# otherwise the subplot displays twice
plt.close(plot.figure)
return concat_display(*[p.html for p in plot.plots])
return plot
return Plot()
......@@ -168,69 +191,31 @@ def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_
column
Number of columns the subplots are organized in.
"""
assert isinstance(U, VectorArray) \
or (isinstance(U, tuple)
and all(isinstance(u, VectorArray) for u in U)
and all(len(u) == len(U[0]) for u in U))
U = (U.to_numpy(),) if isinstance(U, VectorArray) else tuple(u.to_numpy() for u in U)
if not config.HAVE_MATPLOTLIB:
raise ImportError('cannot visualize: import of matplotlib failed')
if not config.HAVE_IPYWIDGETS and len(U[0]) > 1:
raise ImportError('cannot visualize: import of ipywidgets failed')
if isinstance(legend, str):
legend = (legend,)
assert legend is None or isinstance(legend, tuple) and len(legend) == len(U)
class Plot:
class Plot(MPLPlotBase):
def __init__(self):
def _set_limits(self, np_U):
if separate_plots:
if separate_axes:
self.vmins = tuple(np.min(u) for u in U[0])
self.vmaxs = tuple(np.max(u) for u in U[0])
self.vmins = tuple(np.min(u) for u in np_U[0])
self.vmaxs = tuple(np.max(u) for u in np_U[0])
else:
self.vmins = (min(np.min(u) for u in U),) * len(U[0])
self.vmaxs = (max(np.max(u) for u in U),) * len(U[0])
self.vmins = (min(np.min(u) for u in np_U),) * len(np_U[0])
self.vmaxs = (max(np.max(u) for u in np_U),) * len(np_U[0])
else:
self.vmins = (min(np.min(u) for u in U[0]),)
self.vmaxs = (max(np.max(u) for u in U[0]),)
import matplotlib.pyplot as plt
self.vmins = (min(np.min(u) for u in np_U[0]),)
self.vmaxs = (max(np.max(u) for u in np_U[0]),)
if separate_axes:
rows = int(np.ceil(len(U[0]) / columns))
else:
rows = int(np.ceil(len(U) / columns))
self.plots = []
# this _supposed_ to let animations run in sync
sync_timer = None
for i, (vmin, vmax, u) in enumerate(zip(self.vmins, self.vmaxs, U)):
count = 1
if not separate_plots:
count = len(U[0])
figure = plt.figure(i)
ax = plt.axes()
sync_timer = sync_timer or figure.canvas.new_timer()
self.plots.append(Matplotlib1DAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
vmin=vmin, vmax=vmax, count=count, codim=codim, ))
if legend:
ax.set_title(legend[i])
plt.tight_layout()
plt.close(figure)
def __init__(self):
count = 1
super(Plot, self).__init__(U, grid, codim, legend, separate_plots=separate_plots, count=count)
def set(self, U):
def set(self, ind):
np_U = self.U[ind]
if separate_plots:
for u, plot, vmin, vmax in zip(U, self.plots, self.vmins, self.vmaxs):
for u, plot, vmin, vmax in zip(np_U, self.plots, self.vmins, self.vmaxs):
plot.set(u[np.newaxis, ...], vmin=vmin, vmax=vmax)
else:
self.plots[0].set(U, vmin=self.vmins[0], vmax=self.vmaxs[0])
plot = Plot()
self.plots[0].set(np_U, vmin=self.vmins[0], vmax=self.vmaxs[0])
if len(U[0]) > 1:
\ No newline at end of file
return Plot()
......@@ -43,6 +43,8 @@ class MatplotlibAxesBase:
blit=True, event_source=sync_timer)
# generating the HTML instance outside this class causes the plot display to fail
self.html = HTML(self.anim.to_jshtml())
else:
self.set(self.U[0])
@abstractmethod
def _plot_init(self):
......@@ -101,11 +103,13 @@ class MatplotlibPatchAxes(MatplotlibAxesBase):
class Matplotlib1DAxes(MatplotlibAxesBase):
def __init__(self, U, ax, figure, grid, count=1, vmin=None, vmax=None, codim=1, sync_timer=None):
def __init__(self, U, ax, figure, grid, count=1, vmin=None, vmax=None, codim=1, separate_plots=False,
sync_timer=None):
assert isinstance(grid, OnedGrid)
assert codim in (0, 1)
self.count = count
self.separate_plots = separate_plots
super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
sync_timer=sync_timer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment