Commit 43ae7a56 authored by René Fritze's avatar René Fritze Committed by René Fritze

[vis] simplify MPL object signatures

parent 9630ca8b
......@@ -13,20 +13,6 @@ from pymor.discretizers.builtin.gui.matplotlib import MatplotlibPatchAxes, Matpl
from pymor.vectorarrays.interface import VectorArray
class concat_display(object):
"""Display HTML representation of multiple objects"""
template = """<div style="float: left; padding: 10px;">{0}</div>"""
def __init__(self, *args):
self.args = args
def _repr_html_(self):
return '\n'.join(self.template.format(a._repr_html_()) for a in self.args)
def __repr__(self):
return '\n\n'.join(repr(a) for a in self.args)
class MPLPlotBase:
def __init__(self, U, grid, codim, legend, bounding_box=None, separate_colorbars=False, count=None,
......@@ -58,12 +44,12 @@ class MPLPlotBase:
ax = plt.axes()
sync_timer = sync_timer or figure.canvas.new_timer()
if grid.dim == 2:
self.plots.append(MatplotlibPatchAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
self.plots.append(MatplotlibPatchAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid,
bounding_box=bounding_box, vmin=vmin, vmax=vmax,
codim=codim, colorbar=separate_colorbars or i == len(U) - 1))
else:
assert count
self.plots.append(Matplotlib1DAxes(U=u, ax=ax, figure=figure, sync_timer=sync_timer, grid=grid,
self.plots.append(Matplotlib1DAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid,
vmin=vmin, vmax=vmax, count=count, codim=codim,
separate_plots=separate_plots))
if self.legend:
......@@ -75,9 +61,10 @@ class MPLPlotBase:
for fig_id in self.fig_ids:
# avoids figure double display
plt.close(fig_id)
self._cd = concat_display(*[p.html for p in self.plots])
html = [p.html for p in self.plots]
template = """<div style="float: left; padding: 10px;">{0}</div>"""
# IPython display system checks for presence and calls this func
self._repr_html_ = self._cd._repr_html_
self._repr_html_ = lambda : '\n'.join(template.format(a._repr_html_()) for a in html)
else:
self._out = widgets.Output()
with self._out:
......
......@@ -19,14 +19,13 @@ from pymor.discretizers.builtin.grids.referenceelements import triangle, square
class MatplotlibAxesBase:
def __init__(self, ax, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2):
def __init__(self, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2):
self.vmin = vmin
self.vmax = vmax
self.codim = codim
self.ax = ax
self.grid = grid
# TODO plt.axes
self.ax = ax
self.ax = figure.gca()
self.figure = figure
self.codim = codim
self.grid = grid
......@@ -59,7 +58,7 @@ class MatplotlibAxesBase:
class MatplotlibPatchAxes(MatplotlibAxesBase):
def __init__(self, ax, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2,
def __init__(self, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2,
colorbar=True, sync_timer=None):
assert grid.reference_element in (triangle, square)
assert grid.dim == 2
......@@ -73,7 +72,7 @@ class MatplotlibPatchAxes(MatplotlibAxesBase):
self.reference_element = grid.reference_element
self.colorbar = colorbar
super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
sync_timer=sync_timer)
def _plot_init(self):
......@@ -103,14 +102,14 @@ class MatplotlibPatchAxes(MatplotlibAxesBase):
class Matplotlib1DAxes(MatplotlibAxesBase):
def __init__(self, U, ax, figure, grid, count=1, vmin=None, vmax=None, codim=1, separate_plots=False,
def __init__(self, U, figure, grid, count=1, vmin=None, vmax=None, codim=1, separate_plots=False,
sync_timer=None):
assert isinstance(grid, OnedGrid)
assert codim in (0, 1)
self.count = count
self.separate_plots = separate_plots
super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
super().__init__(U=U, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
sync_timer=sync_timer)
def _plot_init(self):
......@@ -235,7 +234,8 @@ if config.HAVE_QT and config.HAVE_MATPLOTLIB:
self.setMinimumSize(300, 300)
self.setSizePolicy(QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding))
self.patch_axes = MatplotlibPatchAxes(self.figure, grid, bounding_box, vmin, vmax, codim)
self.patch_axes = MatplotlibPatchAxes(figure=self.figure, grid=grid, bounding_box=bounding_box,
vmin=vmin, vmax=vmax, codim=codim)
def set(self, U, vmin=None, vmax=None):
self.patch_axes.set(U, vmin, vmax)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment