Commit 381dad21 authored by René Fritze's avatar René Fritze Committed by René Fritze

[vis] fix for 1D MPL same plots/axes

parent 9242f951
......@@ -29,6 +29,9 @@ class MPLPlotBase:
self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else [U[0].uid] * len(U)
self.U = U = (U.to_numpy().astype(np.float64, copy=False),) if isinstance(U, VectorArray) else \
tuple(u.to_numpy().astype(np.float64, copy=False) for u in U)
if grid.dim == 1 and len(U[0]) > 1 and not separate_plots:
raise NotImplementedError('Plotting of VectorArrays with length > 1 is only available with '
'`separate_plots=True`')
if not config.HAVE_MATPLOTLIB:
raise ImportError('cannot visualize: import of matplotlib failed')
......@@ -166,7 +169,7 @@ def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None,
def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_plots=False, separate_axes=False,
def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_plots=True, separate_axes=False,
columns=2):
"""Visualize scalar data associated to a one-dimensional |Grid| as a plot.
......@@ -223,6 +226,5 @@ def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_
plot.set(u, vmin=vmin, vmax=vmax)
else:
self.plots[0].set(np_U, vmin=self.vmins, vmax=self.vmaxs)
pl = Plot()
pl.set(0)
return pl
return Plot()
......@@ -36,14 +36,14 @@ class MatplotlibAxesBase:
self.codim = codim
self.grid = grid
self.separate_axes = separate_axes
self.count = len(U) if separate_axes else 1
self.count = len(U) if separate_axes or isinstance(U, tuple) else 1
self._plot_init()
# assignment delayed to ensure _plot_init works w/o data
self.U = U
# Rest is only needed with animation
if not separate_axes and U is not None and len(U) > 1:
if not separate_axes and self.count == 1:
assert len(self.ax) == 1
delay_between_frames = 200 # ms
self.anim = animation.FuncAnimation(figure, self.animate,
......@@ -162,9 +162,15 @@ class Matplotlib1DAxes(MatplotlibAxesBase):
def set(self, U, vmin=None, vmax=None):
self.vmin = self.vmin if vmin is None else vmin
self.vmax = self.vmax if vmax is None else vmax
for i, (u, ax) in enumerate(zip(U, self.ax)):
self._set(u, i)
pad = (self.vmax - self.vmin) * 0.1
if isinstance(U, tuple):
for i, u in enumerate(U):
self._set(u, i)
else:
for i, (u, _) in enumerate(zip(U, self.ax)):
self._set(u, i)
pad = (self.vmax - self.vmin) * 0.1
for ax in self.ax:
ax.set_ylim(self.vmin - pad, self.vmax + pad)
......
......@@ -117,7 +117,7 @@ class OnedVisualizer(BasicObject):
backend = backend or ('jupyter' if is_jupyter() else None)
self.__auto_init(locals())
def visualize(self, U, m, title=None, legend=None, separate_plots=False,
def visualize(self, U, m, title=None, legend=None, separate_plots=True,
separate_axes=False, block=None, filename=None, columns=2):
"""Visualize the provided data.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment