matplotlib.py 9.51 KB
Newer Older
1
# This file is part of the pyMOR project (http://www.pymor.org).
2
# Copyright 2013-2020 pyMOR developers and contributors. All rights reserved.
3
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)
4
import itertools
5
from pprint import pprint
6

7
import numpy as np
8 9
from IPython.core.display import display
from ipywidgets import HTML, HBox, widgets, Layout
10
import matplotlib.pyplot as plt
11

12
from pymor.core.config import config
13
from pymor.discretizers.builtin.gui.matplotlib import MatplotlibPatchAxes, Matplotlib1DAxes
14
from pymor.vectorarrays.interface import VectorArray
15 16


17 18
class MPLPlotBase:
    """Common base class of the matplotlib-based Jupyter visualizers.

    Subclasses must implement `_set_limits(np_U)`. It receives the tuple of float64
    NumPy arrays built from `U` and has to set `self.vmins` and `self.vmaxs`.

    Parameters
    ----------
    U
        |VectorArray|, or tuple of |VectorArrays| of equal length, to plot. A list of
        |VectorArrays| is accepted as well and treated like a tuple.
    grid
        The underlying |Grid|. For `grid.dim == 2` patch plots are created, otherwise
        1D line plots.
    codim
        Codimension of the grid entities the data is attached to.
    legend
        `None`, a single string, or a tuple (or list) of strings with one entry per
        array in `U`. The entries are used as axes titles.
    bounding_box
        Bounding box forwarded to the 2D patch plots.
    separate_colorbars
        Only relevant for 2D plots with `separate_plots=True`. If `True`, every figure
        gets a colorbar; otherwise only the last one does.
    columns
        Number of columns the subplots are organized in.
    separate_plots
        If `True`, every array in `U` is drawn into its own figure.
    separate_axes
        If `True`, use separate axes instead of an animation for multiple frames.
    """

    def __init__(self, U, grid, codim, legend, bounding_box=None, separate_colorbars=False, columns=2,
                 separate_plots=False, separate_axes=False):
        # lists of arrays are handled exactly like tuples
        if isinstance(U, list):
            U = tuple(U)
        assert isinstance(U, VectorArray) \
               or (isinstance(U, tuple)
                   and len(U) > 0
                   and all(isinstance(u, VectorArray) for u in U)
                   and all(len(u) == len(U[0]) for u in U))
        if separate_plots:
            self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else tuple(u.uid for u in U)
        else:
            # using the same id multiple times lets us automagically re-use the same figure
            self.fig_ids = (U.uid,) if isinstance(U, VectorArray) else (U[0].uid,) * len(U)

        # from here on, work on float64 NumPy arrays instead of VectorArrays
        if isinstance(U, VectorArray):
            U = (U.to_numpy().astype(np.float64, copy=False),)
        else:
            U = tuple(u.to_numpy().astype(np.float64, copy=False) for u in U)
        self.U = U

        if grid.dim == 1 and len(U[0]) > 1 and not separate_plots:
            raise NotImplementedError('Plotting of VectorArrays with length > 1 is only available with '
                                      '`separate_plots=True`')

        if not config.HAVE_MATPLOTLIB:
            raise ImportError('cannot visualize: import of matplotlib failed')
        if not config.HAVE_IPYWIDGETS and len(U[0]) > 1:
            raise ImportError('cannot visualize: import of ipywidgets failed')

        if isinstance(legend, str):
            legend = (legend,)
        elif isinstance(legend, list):
            # accept a list of strings as well as a tuple
            legend = tuple(legend)
        self.legend = legend
        assert self.legend is None or isinstance(self.legend, tuple) and len(self.legend) == len(U)

        self._set_limits(U)

        self.plots = []
        # this is _supposed_ to let animations run in sync
        sync_timer = None

        do_animation = not separate_axes and len(U[0]) > 1

        if separate_plots:
            for i, (vmin, vmax, u) in enumerate(zip(self.vmins, self.vmaxs, U)):
                figure = plt.figure(self.fig_ids[i])
                sync_timer = sync_timer or figure.canvas.new_timer()
                if grid.dim == 2:
                    plot = MatplotlibPatchAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid, vmin=vmin, vmax=vmax,
                                               bounding_box=bounding_box, codim=codim, columns=columns,
                                               colorbar=separate_colorbars or i == len(U) - 1)
                else:
                    plot = Matplotlib1DAxes(U=u, figure=figure, sync_timer=sync_timer, grid=grid, vmin=vmin, vmax=vmax,
                                            columns=columns, codim=codim, separate_axes=separate_axes)
                if self.legend:
                    plot.ax[0].set_title(self.legend[i])
                self.plots.append(plot)
        else:
            figure = plt.figure(self.fig_ids[0])
            # sync_timer is still None here, so a fresh timer is always created
            sync_timer = figure.canvas.new_timer()
            if grid.dim == 2:
                plot = MatplotlibPatchAxes(U=U, figure=figure, sync_timer=sync_timer, grid=grid, vmin=self.vmins,
                                           vmax=self.vmaxs, bounding_box=bounding_box, codim=codim, columns=columns,
                                           colorbar=True)
            else:
                plot = Matplotlib1DAxes(U=U, figure=figure, sync_timer=sync_timer, grid=grid, vmin=self.vmins,
                                        vmax=self.vmaxs, columns=columns, codim=codim, separate_axes=separate_axes)
            if self.legend:
                plot.ax[0].set_title(self.legend[0])
            self.plots.append(plot)

        if do_animation:
            for fig_id in self.fig_ids:
                # avoids figure double display
                plt.close(fig_id)
            # presumably each plot's `html` attribute is an IPython-displayable object
            # exposing `_repr_html_` -- NOTE(review): confirm in the axes classes
            html = [p.html for p in self.plots]
            template = """<div style="float: left; padding: 10px;">{0}</div>"""
            # IPython display system checks for presence and calls this func
            self._repr_html_ = lambda: '\n'.join(template.format(a._repr_html_()) for a in html)
        else:
            self._out = widgets.Output()
            with self._out:
                plt.show()
            # IPython display system checks for presence and calls this func
            self._ipython_display_ = self._out._ipython_display_




97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None, legend=None,
                    separate_colorbars=False, rescale_colorbars=False, columns=2):
    """Visualize scalar data associated to a two-dimensional |Grid| as a patch plot.

    The grid's |ReferenceElement| must be the triangle or square. The data can either
    be attached to the faces or vertices of the grid.

    Parameters
    ----------
    grid
        The underlying |Grid|.
    U
        |VectorArray| of the data to visualize. If `len(U) > 1`, the data is visualized
        as a time series of plots. Alternatively, a tuple of |VectorArrays| can be
        provided, in which case a subplot is created for each entry of the tuple. The
        lengths of all arrays have to agree.
    bounding_box
        A bounding box in which the grid is contained.
    codim
        The codimension of the entities the data in `U` is attached to (either 0 or 2).
    title
        Title of the plot. Currently not used by this implementation.
    legend
        Description of the data that is plotted. Most useful if `U` is a tuple in which
        case `legend` has to be a tuple of strings of the same length.
    separate_colorbars
        If `True`, use separate colorbars for each subplot.
    rescale_colorbars
        If `True`, compute the colorbar limits from the first frame of each array only,
        instead of from all frames. Rescaling the colorbars in every frame is not
        implemented yet.
    columns
        The number of columns in the visualizer GUI in case multiple plots are displayed
        at the same time.
    """

    class Plot(MPLPlotBase):

        def _set_limits(self, np_U):
            # TODO: true per-frame rescaling is not set up; with `rescale_colorbars`
            # only the first frame of every array determines the limits.
            data = tuple(u[0] for u in np_U) if rescale_colorbars else np_U
            if separate_colorbars:
                # individual limits for every subplot
                self.vmins = tuple(np.min(d) for d in data)
                self.vmaxs = tuple(np.max(d) for d in data)
            else:
                # one common pair of limits, repeated for every subplot
                self.vmins = (min(np.min(d) for d in data),) * len(np_U)
                self.vmaxs = (max(np.max(d) for d in data),) * len(np_U)

        def __init__(self):
            super().__init__(U, grid, codim, legend, bounding_box=bounding_box, columns=columns,
                             separate_colorbars=separate_colorbars, separate_plots=True,
                             separate_axes=False)

    return Plot()
156

157 158


159
def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_plots=True, separate_axes=False,
                            columns=2):
    """Visualize scalar data associated to a one-dimensional |Grid| as a plot.

    The grid's |ReferenceElement| must be the line. The data can either
    be attached to the subintervals or vertices of the grid.

    Parameters
    ----------
    grid
        The underlying |Grid|.
    U
        |VectorArray| of the data to visualize. If `len(U) > 1`, the data is visualized
        as an animation in a single axes object or a series of axes, depending on the
        `separate_axes` switch. This currently requires `separate_plots=True`. It is
        also possible to provide a tuple of |VectorArrays|, in which case several plots
        are made into one or multiple figures, depending on the `separate_plots` switch.
        The lengths of all arrays have to agree.
    codim
        The codimension of the entities the data in `U` is attached to (either 0 or 1).
    title
        Title of the plot. Currently not used by this implementation.
    legend
        Description of the data that is plotted. Most useful if `U` is a tuple in which
        case `legend` has to be a tuple of strings of the same length.
    separate_plots
        If `True`, use multiple figures to visualize multiple |VectorArrays|.
    separate_axes
        If `True`, use separate axes for each figure instead of an Animation.
    columns
        Number of columns the subplots are organized in.
    """

    class Plot(MPLPlotBase):

        def _set_limits(self, np_U):
            if not separate_plots:
                # a single figure: one scalar pair of limits for all arrays
                self.vmins = min(np.min(u) for u in np_U)
                self.vmaxs = max(np.max(u) for u in np_U)
            elif separate_axes:
                # individual limits for every array
                self.vmins = tuple(np.min(u) for u in np_U)
                self.vmaxs = tuple(np.max(u) for u in np_U)
            else:
                # common limits, repeated once per figure
                self.vmins = (min(np.min(u) for u in np_U),) * len(np_U)
                self.vmaxs = (max(np.max(u) for u in np_U),) * len(np_U)

        def __init__(self):
            super().__init__(U, grid, codim, legend, separate_plots=separate_plots, columns=columns,
                             separate_axes=separate_axes)

    return Plot()