Commit 79ef3c3f authored by René Fritze's avatar René Fritze Committed by René Fritze

[vis] refactor MPL axes

This mostly reduces code duplication between them
parent fd18d83b
......@@ -216,8 +216,8 @@ def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_
figure = plt.figure(i)
ax = plt.axes()
sync_timer = sync_timer or figure.canvas.new_timer()
self.plots.append(Matplotlib1DAxes(u, ax, figure, sync_timer, grid, count, vmin=vmin, vmax=vmax,
codim=codim))
self.plots.append(Matplotlib1DAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
vmin=vmin, vmax=vmax, count=count, codim=codim, ))
if legend:
ax.set_title(legend[i])
......
......@@ -10,16 +10,55 @@ scalar data assigned to one- and two-dimensional grids using
import numpy as np
from IPython.core.display import display, HTML
from matplotlib import animation, pyplot
from pymor.core.base import abstractmethod
from pymor.core.config import config
from pymor.discretizers.builtin.grids.constructions import flatten_grid
from pymor.discretizers.builtin.grids.referenceelements import triangle, square
class MatplotlibPatchAxes:
class MatplotlibAxesBase:
def __init__(self, U, ax, figure, sync_timer, grid, bounding_box=None, vmin=None, vmax=None, codim=2,
colorbar=True):
def __init__(self, ax, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2):
self.vmin = vmin
self.vmax = vmax
self.codim = codim
self.ax = ax
self.grid = grid
# TODO plt.axes
self.ax = ax
self.figure = figure
self.codim = codim
self.grid = grid
self._plot_init()
# assignment delayed to ensure _plot_init works w/o data
self.U = U
# Rest is only needed with animation
if U is not None and len(U) > 1:
delay_between_frames = 200 # ms
self.anim = animation.FuncAnimation(figure, self.set,
frames=U, interval=delay_between_frames,
blit=True, event_source=sync_timer)
# generating the HTML instance outside this class causes the plot display to fail
self.html = HTML(self.anim.to_jshtml())
@abstractmethod
def _plot_init(self):
"""Setup MPL figure display with empty data."""
pass
@abstractmethod
def set(self, U, vmin=None, vmax=None):
"""Load new data into existing plot objects."""
pass
class MatplotlibPatchAxes(MatplotlibAxesBase):
def __init__(self, ax, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2,
colorbar=True, sync_timer=None):
assert grid.reference_element in (triangle, square)
assert grid.dim == 2
assert codim in (0, 2)
......@@ -30,28 +69,22 @@ class MatplotlibPatchAxes:
self.coordinates = coordinates
self.entity_map = entity_map
self.reference_element = grid.reference_element
self.vmin = vmin
self.vmax = vmax
self.codim = codim
self.U = U
self.colorbar = colorbar
super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
sync_timer=sync_timer)
def _plot_init(self):
if self.codim == 2:
self.p = ax.tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
self.p = self.ax.tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
np.zeros(len(self.coordinates)),
vmin=self.vmin, vmax=self.vmax, shading='gouraud')
else:
self.p = ax.tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
self.p = self.ax.tripcolor(self.coordinates[:, 0], self.coordinates[:, 1], self.subentities,
facecolors=np.zeros(len(self.subentities)),
vmin=self.vmin, vmax=self.vmax, shading='flat')
if colorbar:
figure.colorbar(self.p, ax=ax)
delay_between_frames = 200 # ms
self.anim = animation.FuncAnimation(figure, self.set,
frames=U, interval=delay_between_frames,
blit=True, event_source=sync_timer)
# generating the HTML instance outside this class causes the plot display to fail
self.html = HTML(self.anim.to_jshtml())
if self.colorbar:
self.figure.colorbar(self.p, ax=self.ax)
def set(self, U, vmin=None, vmax=None):
self.vmin = self.vmin if vmin is None else vmin
......@@ -66,46 +99,35 @@ class MatplotlibPatchAxes:
return (self.p,)
class Matplotlib1DAxes:
class Matplotlib1DAxes(MatplotlibAxesBase):
def __init__(self, U, axes, figure, sync_timer, grid, count, vmin=None, vmax=None, codim=1):
def __init__(self, U, ax, figure, grid, count=1, vmin=None, vmax=None, codim=1, sync_timer=None):
assert isinstance(grid, OnedGrid)
assert codim in (0, 1)
self.codim = codim
self.grid = grid
self.vmin = vmin
self.vmax = vmax
self.count = count
self.U = U
super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
sync_timer=sync_timer)
centers = grid.centers(1)
if grid.identify_left_right:
centers = np.concatenate((centers, [[grid.domain[1]]]), axis=0)
def _plot_init(self):
centers = self.grid.centers(1)
if self.grid.identify_left_right:
centers = np.concatenate((centers, [[self.grid.domain[1]]]), axis=0)
self.periodic = True
else:
self.periodic = False
if codim == 1:
if self.codim == 1:
xs = centers
else:
xs = np.repeat(centers, 2)[1:-1]
lines = ()
for i in range(count):
l, = axes.plot(xs, np.zeros_like(xs))
for i in range(self.count):
l, = self.ax.plot(xs, np.zeros_like(xs))
lines = lines + (l,)
pad = (vmax - vmin) * 0.1
axes.set_ylim(vmin - pad, vmax + pad)
self.axes = axes
pad = (self.vmax - self.vmin) * 0.1
self.ax.set_ylim(self.vmin - pad, self.vmax + pad)
self.lines = lines
delay_between_frames = 200 # ms
self.anim = animation.FuncAnimation(figure, self.set,
frames=U, interval=delay_between_frames,
blit=True, event_source=sync_timer)
# generating the HTML instance outside this class causes the plot display to fail
self.html = HTML(self.anim.to_jshtml())
def set(self, u, vmin=None, vmax=None):
self.vmin = self.vmin if vmin is None else vmin
self.vmax = self.vmax if vmax is None else vmax
......@@ -119,7 +141,7 @@ class Matplotlib1DAxes:
self.lines[i].set_ydata(np.repeat(u, 2))
pad = (self.vmax - self.vmin) * 0.1
self.axes.set_ylim(self.vmin - pad, self.vmax + pad)
self.ax.set_ylim(self.vmin - pad, self.vmax + pad)
return self.lines
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment