# This file is part of the pyMOR project (http://www.pymor.org).
# Copyright 2013-2020 pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)

import itertools

6
import numpy as np
7
from ipywidgets import HTML, HBox
8
import matplotlib.pyplot as plt
9

10
from pymor.core.config import config
11
from pymor.discretizers.builtin.gui.matplotlib import MatplotlibPatchAxes, Matplotlib1DAxes
12
from pymor.vectorarrays.interface import VectorArray
13 14


15 16 17 18 19 20 21 22 23 24 25 26 27 28
class concat_display:
    """Display the HTML representations of several objects side by side.

    When shown in Jupyter, :meth:`_repr_html_` wraps each object's HTML in its
    own left-floating ``<div>`` (see `template`). In a plain console,
    :meth:`__repr__` joins the objects' ``repr`` strings with blank lines.

    Parameters
    ----------
    args
        The objects to display. Objects whose ``_repr_html_`` returns a string
        are rendered through it. For all other objects (no ``_repr_html_``, or
        ``_repr_html_`` returning `None`, which IPython's display protocol
        treats as "no HTML representation"), the HTML-escaped ``repr`` is
        shown inside a ``<pre>`` element instead.
    """
    template = """<div style="float: left; padding: 10px;">{0}</div>"""

    def __init__(self, *args):
        self.args = args

    def _repr_html_(self):
        """Return the concatenated, individually wrapped HTML of all objects."""
        from html import escape
        parts = []
        for a in self.args:
            repr_html = getattr(a, '_repr_html_', None)
            content = repr_html() if callable(repr_html) else None
            if content is None:
                # fall back to plain text so that a single object without an
                # HTML representation does not break (or pollute) the display
                content = '<pre>{}</pre>'.format(escape(repr(a)))
            parts.append(self.template.format(content))
        return '\n'.join(parts)

    def __repr__(self):
        return '\n\n'.join(repr(a) for a in self.args)


29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
def _colorbar_limits(arrays, separate):
    """Return ``(vmins, vmaxs)``, one entry per array in `arrays`.

    If `separate` is `True`, each array gets its own minimum/maximum. Otherwise
    every entry is the minimum/maximum taken over all arrays, so that all
    subplots share one color scale.
    """
    if separate:
        return tuple(np.min(a) for a in arrays), tuple(np.max(a) for a in arrays)
    count = len(arrays)
    return (min(np.min(a) for a in arrays),) * count, (max(np.max(a) for a in arrays),) * count


def visualize_patch(grid, U, bounding_box=([0, 0], [1, 1]), codim=2, title=None, legend=None,
                    separate_colorbars=False, rescale_colorbars=False, columns=2):
    """Visualize scalar data associated to a two-dimensional |Grid| as a patch plot.

    The grid's |ReferenceElement| must be the triangle or square. The data can either
    be attached to the faces or vertices of the grid.

    Parameters
    ----------
    grid
        The underlying |Grid|.
    U
        |VectorArray| of the data to visualize. If `len(U) > 1`, the data is visualized
        as a time series of plots. Alternatively, a tuple of |VectorArrays| can be
        provided, in which case a subplot is created for each entry of the tuple. The
        lengths of all arrays have to agree.
    bounding_box
        A bounding box in which the grid is contained.
    codim
        The codimension of the entities the data in `U` is attached to (either 0 or 2).
    title
        Title of the plot (set as the figure's suptitle).
    legend
        Description of the data that is plotted. Most useful if `U` is a tuple in which
        case `legend` has to be a tuple of strings of the same length.
    separate_colorbars
        If `True`, use separate colorbars for each subplot.
    rescale_colorbars
        If `True`, rescale colorbars to data in each frame.
    columns
        The number of columns in the visualizer GUI in case multiple plots are displayed
        at the same time.

    Returns
    -------
    If the arrays contain more than one vector, a `concat_display` of the `html`
    attributes of the created subplots; otherwise the internal plot object holding
    the figure (`figure`) and the subplot objects (`plots`).

    Raises
    ------
    ImportError
        If matplotlib is unavailable, or if ipywidgets is unavailable and a time
        series has to be shown.
    """

    assert isinstance(U, VectorArray) \
        or (isinstance(U, tuple)
            and all(isinstance(u, VectorArray) for u in U)
            and all(len(u) == len(U[0]) for u in U))
    U = (U.to_numpy().astype(np.float64, copy=False),) if isinstance(U, VectorArray) else \
        tuple(u.to_numpy().astype(np.float64, copy=False) for u in U)

    if not config.HAVE_MATPLOTLIB:
        raise ImportError('cannot visualize: import of matplotlib failed')
    if not config.HAVE_IPYWIDGETS and len(U[0]) > 1:
        raise ImportError('cannot visualize: import of ipywidgets failed')

    if isinstance(legend, str):
        legend = (legend,)
    assert legend is None or isinstance(legend, tuple) and len(legend) == len(U)
    if len(U) < 2:
        columns = 1

    class Plot:

        def __init__(self):
            # With rescale_colorbars the limits follow the displayed frame (starting
            # with the first one); otherwise they span all frames of each array.
            # NOTE(review): the original code carried the remark "todo rescaling not
            # set up" for separate colorbars -- confirm per-frame rescaling is honoured.
            frames = tuple(u[0] for u in U) if rescale_colorbars else U
            self.vmins, self.vmaxs = _colorbar_limits(frames, separate_colorbars)

            rows = int(np.ceil(len(U) / columns))
            # plt.subplots returns (figure, axes); keep the figure that owns the axes,
            # since it is the one that must be closed below to avoid a double display
            self.figure, axes = plt.subplots(nrows=rows, ncols=columns, squeeze=False)
            figure = self.figure
            if title is not None:
                figure.suptitle(title)

            self.plots = []
            coords = itertools.product(range(rows), range(columns))
            for i, (vmin, vmax, u, c) in enumerate(zip(self.vmins, self.vmaxs, U, coords)):
                ax = axes[c]
                self.plots.append(MatplotlibPatchAxes(U=u, ax=ax, figure=figure, grid=grid,
                                                      bounding_box=bounding_box, vmin=vmin, vmax=vmax,
                                                      codim=codim,
                                                      colorbar=separate_colorbars or i == len(U) - 1))
                if legend:
                    ax.set_title(legend[i])

            # hide the grid cells left empty when len(U) is not a multiple of columns
            for ax in axes.ravel()[len(U):]:
                ax.set_visible(False)

        def set(self, U, ind):
            """Show frame `ind` of each array in `U`, updating the limits if requested."""
            if rescale_colorbars:
                self.vmins, self.vmaxs = _colorbar_limits(tuple(u[ind] for u in U), separate_colorbars)

            for u, plot, vmin, vmax in zip(U, self.plots, self.vmins, self.vmaxs):
                plot.set(u[ind], vmin=vmin, vmax=vmax)

    plot = Plot()

    if len(U[0]) > 1:
        # otherwise the subplot displays twice
        plt.close(plot.figure)
        return concat_display(*[p.html for p in plot.plots])

    return plot
142

143 144


145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
def visualize_matplotlib_1d(grid, U, codim=1, title=None, legend=None, separate_plots=False, separate_axes=False, columns=2):
    """Visualize scalar data associated to a one-dimensional |Grid| as a plot.

    The grid's |ReferenceElement| must be the line. The data can either
    be attached to the subintervals or vertices of the grid.

    Parameters
    ----------
    grid
        The underlying |Grid|.
    U
        |VectorArray| of the data to visualize. If `len(U) > 1`, the data is visualized
        as a time series of plots. Alternatively, a tuple of |VectorArrays| can be
        provided, in which case several plots are made into the same axes. The
        lengths of all arrays have to agree.
    codim
        The codimension of the entities the data in `U` is attached to (either 0 or 1).
    title
        Title of the plot.
    legend
        Description of the data that is plotted. Most useful if `U` is a tuple in which
        case `legend` has to be a tuple of strings of the same length.
    separate_plots
        If `True`, use subplots to visualize multiple |VectorArrays|.
    separate_axes
        If `True`, use separate axes for each subplot.
    columns
        Number of columns the subplots are organized in.
    """
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
    assert isinstance(U, VectorArray) \
        or (isinstance(U, tuple)
            and all(isinstance(u, VectorArray) for u in U)
            and all(len(u) == len(U[0]) for u in U))
    U = (U.to_numpy(),) if isinstance(U, VectorArray) else tuple(u.to_numpy() for u in U)

    if not config.HAVE_MATPLOTLIB:
        raise ImportError('cannot visualize: import of matplotlib failed')
    if not config.HAVE_IPYWIDGETS and len(U[0]) > 1:
        raise ImportError('cannot visualize: import of ipywidgets failed')

    if isinstance(legend, str):
        legend = (legend,)
    assert legend is None or isinstance(legend, tuple) and len(legend) == len(U)


    class Plot:

        def __init__(self):
            if separate_plots:
194 195 196 197 198 199
                if separate_axes:
                    self.vmins = tuple(np.min(u) for u in U[0])
                    self.vmaxs = tuple(np.max(u) for u in U[0])
                else:
                    self.vmins = (min(np.min(u) for u in U),) * len(U[0])
                    self.vmaxs = (max(np.max(u) for u in U),) * len(U[0])
200
            else:
201 202
                self.vmins = (min(np.min(u) for u in U[0]),)
                self.vmaxs = (max(np.max(u) for u in U[0]),)
203 204 205

            import matplotlib.pyplot as plt

206 207 208 209
            if separate_axes:
                rows = int(np.ceil(len(U[0]) / columns))
            else:
                rows = int(np.ceil(len(U) / columns))
210 211

            self.plots = []
212 213 214 215

            self.figure, axes = plt.subplots(nrows=rows, ncols=columns, squeeze=False)
            coord = itertools.product(range(rows), range(columns))
            for i, (vmin, vmax, u, c) in enumerate(zip(self.vmins, self.vmaxs, U, coord)):
216 217 218
                count = 1
                if not separate_plots:
                    count = len(U[0])
219
                ax = axes[c]
220
                self.plots.append(Matplotlib1DAxes(u, ax, self.figure, grid, count, vmin=vmin, vmax=vmax,
221 222 223 224 225 226 227 228 229 230 231 232 233 234
                                                   codim=codim))
                if legend:
                    ax.set_title(legend[i])

            plt.tight_layout()

        def set(self, U):
            if separate_plots:
                for u, plot, vmin, vmax in zip(U, self.plots, self.vmins, self.vmaxs):
                    plot.set(u[np.newaxis, ...], vmin=vmin, vmax=vmax)
            else:
                self.plots[0].set(U, vmin=self.vmins[0], vmax=self.vmaxs[0])

    plot = Plot()
235 236

    if len(U[0]) > 1: