Commit 9242f951 authored by René Fritze's avatar René Fritze Committed by René Fritze

[vis] fix for 1D MPL separate plots/axes

parent 43ae7a56
......@@ -2,6 +2,7 @@
# Copyright 2013-2020 pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)
import itertools
from pprint import pprint
import numpy as np
from IPython.core.display import display
......@@ -15,14 +16,18 @@ from pymor.vectorarrays.interface import VectorArray
class MPLPlotBase:
def __init__(self, U, grid, codim, legend, bounding_box=None, separate_colorbars=False, count=None,
separate_plots=False):
def __init__(self, U, grid, codim, legend, bounding_box=None, separate_colorbars=False, columns=2,
separate_plots=False, separate_axes=False):
assert isinstance(U, VectorArray) \
or (isinstance(U, tuple)
and all(isinstance(u, VectorArray) for u in U)
and all(len(u) == len(U[0]) for u in U))
self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else tuple(u.uid for u in U)
U = (U.to_numpy().astype(np.float64, copy=False),) if isinstance(U, VectorArray) else \
if separate_plots:
self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else tuple(u.uid for u in U)
else:
# using the same id multiple times lets us automagically re-use the same figure
self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else [U[0].uid] * len(U)
self.U = U = (U.to_numpy().astype(np.float64, copy=False),) if isinstance(U, VectorArray) else \
tuple(u.to_numpy().astype(np.float64, copy=False) for u in U)
if not config.HAVE_MATPLOTLIB:
......@@ -37,25 +42,36 @@ class MPLPlotBase:
# this _supposed_ to let animations run in sync
sync_timer = None
do_animation = len(U[0]) > 1
do_animation = not separate_axes and len(U[0]) > 1
for i, (vmin, vmax, u) in enumerate(zip(self.vmins, self.vmaxs, U)):
figure = plt.figure(self.fig_ids[i])
ax = plt.axes()
if separate_plots:
for i, (vmin, vmax, u) in enumerate(zip(self.vmins, self.vmaxs, U)):
figure = plt.figure(self.fig_ids[i])
sync_timer = sync_timer or figure.canvas.new_timer()
if grid.dim == 2:
plot = MatplotlibPatchAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid, vmin=vmin, vmax=vmax,
bounding_box=bounding_box, codim=codim, columns=columns,
colorbar=separate_colorbars or i == len(U) - 1)
else:
plot = Matplotlib1DAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid, vmin=vmin, vmax=vmax,
columns=columns, codim=codim, separate_axes=separate_axes)
if self.legend:
plot.ax[0].set_title(self.legend[i])
self.plots.append(plot)
# plt.tight_layout()
else:
figure = plt.figure(self.fig_ids[0])
sync_timer = sync_timer or figure.canvas.new_timer()
if grid.dim == 2:
self.plots.append(MatplotlibPatchAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid,
bounding_box=bounding_box, vmin=vmin, vmax=vmax,
codim=codim, colorbar=separate_colorbars or i == len(U) - 1))
plot = MatplotlibPatchAxes(U=U, figure=figure, sync_timer=sync_timer, grid=grid, vmin=self.vmins,
vmax=self.vmaxs, bounding_box=bounding_box, codim=codim, columns=columns,
colorbar=True)
else:
assert count
self.plots.append(Matplotlib1DAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid,
vmin=vmin, vmax=vmax, count=count, codim=codim,
separate_plots=separate_plots))
plot = Matplotlib1DAxes(U=U, figure=figure, sync_timer=sync_timer, grid=grid, vmin=self.vmins,
vmax=self.vmaxs, columns=columns, codim=codim, separate_axes=separate_axes)
if self.legend:
ax.set_title(self.legend[i])
plt.tight_layout()
plot.ax[0].set_title(self.legend[0])
self.plots.append(plot)
if do_animation:
for fig_id in self.fig_ids:
......@@ -129,8 +145,9 @@ def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None,
self.vmaxs = (max(np.max(u) for u in np_U),) * len(np_U)
def __init__(self):
super(Plot, self).__init__(U, grid, codim, legend, bounding_box=bounding_box,
separate_colorbars=separate_colorbars)
super(Plot, self).__init__(U, grid, codim, legend, bounding_box=bounding_box, columns=columns,
separate_colorbars=separate_colorbars, separate_plots=True,
separate_axes=False)
def set(self, ind):
np_U = self.U
......@@ -149,7 +166,8 @@ def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None,
def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_plots=False, separate_axes=False, columns=2):
def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_plots=False, separate_axes=False,
columns=2):
"""Visualize scalar data associated to a one-dimensional |Grid| as a plot.
The grid's |ReferenceElement| must be the line. The data can either
......@@ -161,9 +179,10 @@ def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_
The underlying |Grid|.
U
|VectorArray| of the data to visualize. If `len(U) > 1`, the data is visualized
as a time series of plots. Alternatively, a tuple of |VectorArrays| can be
provided, in which case several plots are made into the same axes. The
lengths of all arrays have to agree.
as an animation in a single axes object or a series of axes, depending on the
`separate_axes` switch. It is also possible to provide a tuple of |VectorArrays|,
in which case several plots are made into one or multiple figures,
depending on the `separate_plots` switch. The lengths of all arrays have to agree.
codim
The codimension of the entities the data in `U` is attached to (either 0 or 1).
title
......@@ -172,9 +191,9 @@ def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_
Description of the data that is plotted. Most useful if `U` is a tuple in which
case `legend` has to be a tuple of strings of the same length.
separate_plots
If `True`, use subplots to visualize multiple |VectorArrays|.
If `True`, use multiple figures to visualize multiple |VectorArrays|.
separate_axes
If `True`, use separate axes for each subplot.
If `True`, use separate axes for each figure instead of an Animation.
column
Number of columns the subplots are organized in.
"""
......@@ -184,25 +203,26 @@ def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_
def _set_limits(self, np_U):
if separate_plots:
if separate_axes:
self.vmins = tuple(np.min(u) for u in np_U[0])
self.vmaxs = tuple(np.max(u) for u in np_U[0])
self.vmins = tuple(np.min(u) for u in np_U)
self.vmaxs = tuple(np.max(u) for u in np_U)
else:
self.vmins = (min(np.min(u) for u in np_U),) * len(np_U[0])
self.vmaxs = (max(np.max(u) for u in np_U),) * len(np_U[0])
self.vmins = (min(np.min(u) for u in np_U),) * len(np_U)
self.vmaxs = (max(np.max(u) for u in np_U),) * len(np_U)
else:
self.vmins = (min(np.min(u) for u in np_U[0]),)
self.vmaxs = (max(np.max(u) for u in np_U[0]),)
self.vmins = min(np.min(u) for u in np_U)
self.vmaxs = max(np.max(u) for u in np_U)
def __init__(self):
count = 1
super(Plot, self).__init__(U, grid, codim, legend, separate_plots=separate_plots, count=count)
super(Plot, self).__init__(U, grid, codim, legend, separate_plots=separate_plots, columns=columns,
separate_axes=separate_axes)
def set(self, ind):
np_U = self.U[ind]
if separate_plots:
for u, plot, vmin, vmax in zip(np_U, self.plots, self.vmins, self.vmaxs):
plot.set(u[np.newaxis, ...], vmin=vmin, vmax=vmax)
plot.set(u, vmin=vmin, vmax=vmax)
else:
self.plots[0].set(np_U, vmin=self.vmins[0], vmax=self.vmaxs[0])
return Plot()
self.plots[0].set(np_U, vmin=self.vmins, vmax=self.vmaxs)
pl = Plot()
pl.set(0)
return pl
......@@ -19,31 +19,43 @@ from pymor.discretizers.builtin.grids.referenceelements import triangle, square
class MatplotlibAxesBase:
def __init__(self, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2):
def __init__(self, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2, separate_axes=False, columns=2):
self.vmin = vmin
self.vmax = vmax
self.codim = codim
self.grid = grid
self.ax = figure.gca()
if separate_axes:
if len(U) == 1:
columns = 1 # otherwise we get a sep axes object with 0 data
rows = int(np.ceil(len(U) / columns))
self.ax = figure.subplots(rows, columns, squeeze=False).flatten()
else:
self.ax = (figure.gca(),)
self.figure = figure
self.codim = codim
self.grid = grid
self.separate_axes = separate_axes
self.count = len(U) if separate_axes else 1
self._plot_init()
# assignment delayed to ensure _plot_init works w/o data
self.U = U
# Rest is only needed with animation
if U is not None and len(U) > 1:
if not separate_axes and U is not None and len(U) > 1:
assert len(self.ax) == 1
delay_between_frames = 200 # ms
self.anim = animation.FuncAnimation(figure, self.set,
self.anim = animation.FuncAnimation(figure, self.animate,
frames=U, interval=delay_between_frames,
blit=True, event_source=sync_timer)
pad = (self.vmax - self.vmin) * 0.1
for ax in self.ax:
ax.set_ylim(self.vmin - pad, self.vmax + pad)
# generating the HTML instance outside this class causes the plot display to fail
self.html = HTML(self.anim.to_jshtml())
else:
self.set(self.U[0])
self.set(self.U)
@abstractmethod
def _plot_init(self):
......@@ -51,14 +63,19 @@ class MatplotlibAxesBase:
pass
@abstractmethod
def set(self, U, vmin=None, vmax=None):
def set(self, U):
"""Load new data into existing plot objects."""
pass
@abstractmethod
def animate(self, u):
"""Load new data into existing plot objects."""
pass
class MatplotlibPatchAxes(MatplotlibAxesBase):
def __init__(self, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2,
def __init__(self, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2, columns=2,
colorbar=True, sync_timer=None):
assert grid.reference_element in (triangle, square)
assert grid.dim == 2
......@@ -72,20 +89,20 @@ class MatplotlibPatchAxes(MatplotlibAxesBase):
self.reference_element = grid.reference_element
self.colorbar = colorbar
super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim, columns=columns,
sync_timer=sync_timer)
def _plot_init(self):
if self.codim == 2:
self.p = self.ax.tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
self.p = self.ax[0].tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
np.zeros(len(self.coordinates)),
vmin=self.vmin, vmax=self.vmax, shading='gouraud')
else:
self.p = self.ax.tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
self.p = self.ax[0].tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
facecolors=np.zeros(len(self.subentities)),
vmin=self.vmin, vmax=self.vmax, shading='flat')
if self.colorbar:
self.figure.colorbar(self.p, ax=self.ax)
self.figure.colorbar(self.p, ax=self.ax[0])
def set(self, U, vmin=None, vmax=None):
self.vmin = self.vmin if vmin is None else vmin
......@@ -102,15 +119,12 @@ class MatplotlibPatchAxes(MatplotlibAxesBase):
class Matplotlib1DAxes(MatplotlibAxesBase):
def __init__(self, U, figure, grid, count=1, vmin=None, vmax=None, codim=1, separate_plots=False,
sync_timer=None):
def __init__(self, U, figure, grid, vmin=None, vmax=None, codim=1, separate_axes=False, sync_timer=None,
columns=2):
assert isinstance(grid, OnedGrid)
assert codim in (0, 1)
self.count = count
self.separate_plots = separate_plots
super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
sync_timer=sync_timer)
super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim, columns=columns,
sync_timer=sync_timer, separate_axes=separate_axes)
def _plot_init(self):
centers = self.grid.centers(1)
......@@ -123,30 +137,36 @@ class Matplotlib1DAxes(MatplotlibAxesBase):
xs = centers
else:
xs = np.repeat(centers, 2)[1:-1]
lines = ()
for i in range(self.count):
l, = self.ax.plot(xs, np.zeros_like(xs))
lines = lines + (l,)
if self.separate_axes:
self.lines = [ax.plot(xs, np.zeros_like(xs))[0] for ax in self.ax]
else:
self.lines = [self.ax[0].plot(xs, np.zeros_like(xs))[0] for _ in range(self.count)]
pad = (self.vmax - self.vmin) * 0.1
self.ax.set_ylim(self.vmin - pad, self.vmax + pad)
self.lines = lines
for ax in self.ax:
ax.set_ylim(self.vmin - pad, self.vmax + pad)
def set(self, u, vmin=None, vmax=None):
self.vmin = self.vmin if vmin is None else vmin
self.vmax = self.vmax if vmax is None else vmax
for i in range(self.count):
if self.codim == 1:
if self.periodic:
self.lines[i].set_ydata(np.concatenate((u, [self.U[0]])))
else:
self.lines[i].set_ydata(u)
def _set(self, u, i):
if self.codim == 1:
if self.periodic:
self.lines[i].set_ydata(np.concatenate((u, [self.U[0]])))
else:
self.lines[i].set_ydata(np.repeat(u, 2))
self.lines[i].set_ydata(u)
else:
self.lines[i].set_ydata(np.repeat(u, 2))
pad = (self.vmax - self.vmin) * 0.1
self.ax.set_ylim(self.vmin - pad, self.vmax + pad)
def animate(self, u):
for i in range(len(self.ax)):
self._set(u, i)
return self.lines
def set(self, U, vmin=None, vmax=None):
self.vmin = self.vmin if vmin is None else vmin
self.vmax = self.vmax if vmax is None else vmax
for i, (u, ax) in enumerate(zip(U, self.ax)):
self._set(u, i)
pad = (self.vmax - self.vmin) * 0.1
ax.set_ylim(self.vmin - pad, self.vmax + pad)
if config.HAVE_QT and config.HAVE_MATPLOTLIB:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment