matplotlib.py 9.06 KB
Newer Older
1
# This file is part of the pyMOR project (http://www.pymor.org).
2
# Copyright 2013-2020 pyMOR developers and contributors. All rights reserved.
Stephan Rave's avatar
Stephan Rave committed
3 4
# License: BSD 2-Clause License (http://opensource.org/licenses/BSD-2-Clause)

5
""" This module provides widgets for displaying plots of
6
scalar data assigned to one- and two-dimensional grids using
7
:mod:`matplotlib`. These widgets are not intended to be used directly.
8
"""
Stephan Rave's avatar
Stephan Rave committed
9

Stephan Rave's avatar
Stephan Rave committed
10
import numpy as np
11 12
from IPython.core.display import display, HTML
from matplotlib import animation, pyplot
René Fritze's avatar
René Fritze committed
13
from pymor.core.base import abstractmethod
Stephan Rave's avatar
Stephan Rave committed
14

15
from pymor.core.config import config
16 17
from pymor.discretizers.builtin.grids.constructions import flatten_grid
from pymor.discretizers.builtin.grids.referenceelements import triangle, square
Stephan Rave's avatar
Stephan Rave committed
18 19


René Fritze's avatar
René Fritze committed
20
class MatplotlibAxesBase:
    """Shared constructor logic for plot helpers that draw into one matplotlib axes.

    The constructor stores the plot parameters on the instance, calls
    :meth:`_plot_init` to create the (still empty) plot objects and then
    feeds the data to :meth:`set`: directly for a single frame, or through a
    `FuncAnimation` when more than one frame is given.

    Parameters
    ----------
    ax
        The matplotlib axes object to draw into (stored as `self.ax`).
    figure
        The matplotlib figure the animation is attached to (stored as `self.figure`).
    sync_timer
        Handed to `FuncAnimation` as `event_source`.
    grid
        The grid object, stored as `self.grid` for use by subclasses.
    U
        Sequence of frames. `None` (or an empty sequence) only sets up the
        empty plot; data can then be supplied later via :meth:`set`.
    vmin, vmax
        Stored as `self.vmin` / `self.vmax` for use by subclasses.
    codim
        Stored as `self.codim` for use by subclasses.
    """

    def __init__(self, ax, figure, sync_timer, grid, U=None, vmin=None, vmax=None, codim=2):
        self.vmin = vmin
        self.vmax = vmax
        self.codim = codim
        self.grid = grid
        # TODO plt.axes
        self.ax = ax
        self.figure = figure

        self._plot_init()

        # assignment delayed to ensure _plot_init works w/o data
        self.U = U
        if U is None or len(U) == 0:
            # no frame available yet: keep the empty plot created by _plot_init
            return

        if len(U) > 1:
            # Rest is only needed with animation
            delay_between_frames = 200  # ms
            self.anim = animation.FuncAnimation(figure, self.set,
                                                frames=U, interval=delay_between_frames,
                                                blit=True, event_source=sync_timer)
            # generating the HTML instance outside this class causes the plot display to fail
            self.html = HTML(self.anim.to_jshtml())
        else:
            self.set(U[0])

    @abstractmethod
    def _plot_init(self):
        """Setup MPL figure display with empty data."""
        pass

    @abstractmethod
    def set(self, U, vmin=None, vmax=None):
        """Load new data into existing plot objects."""
        pass


class MatplotlibPatchAxes(MatplotlibAxesBase):
    """Draw data on a two-dimensional triangle or square grid using `tripcolor`.

    The grid is flattened with `flatten_grid`; for square reference elements
    every element is split into two triangles so that `tripcolor` can be used.
    `bounding_box` is accepted but not used by this class.
    """

    def __init__(self, ax, figure, grid, bounding_box=None, U=None, vmin=None, vmax=None, codim=2,
                 colorbar=True, sync_timer=None):
        assert grid.reference_element in (triangle, square)
        assert grid.dim == 2
        assert codim in (0, 2)

        subentities, coordinates, entity_map = flatten_grid(grid)
        if grid.reference_element is not triangle:
            # two triangles per square: corners (0, 1, 2) and (2, 3, 0)
            subentities = np.vstack((subentities[:, 0:3], subentities[:, [2, 3, 0]]))
        self.subentities = subentities
        self.coordinates = coordinates
        self.entity_map = entity_map
        self.reference_element = grid.reference_element
        self.colorbar = colorbar

        super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
                         sync_timer=sync_timer)

    def _plot_init(self):
        """Create the `tripcolor` artist with all-zero data and, if requested, a colorbar."""
        xs = self.coordinates[:, 0]
        ys = self.coordinates[:, 1]
        if self.codim == 2:
            # one value per coordinate
            self.p = self.ax.tripcolor(xs, ys, self.subentities,
                                       np.zeros(len(self.coordinates)),
                                       vmin=self.vmin, vmax=self.vmax, shading='gouraud')
        else:
            # one value per triangle
            self.p = self.ax.tripcolor(xs, ys, self.subentities,
                                       facecolors=np.zeros(len(self.subentities)),
                                       vmin=self.vmin, vmax=self.vmax, shading='flat')
        if self.colorbar:
            self.figure.colorbar(self.p, ax=self.ax)

    def set(self, U, vmin=None, vmax=None):
        """Put `U` into the existing artist, update the color limits and return the artist as a 1-tuple."""
        if vmin is not None:
            self.vmin = vmin
        if vmax is not None:
            self.vmax = vmax

        # element data on squares has to be duplicated, as each square was split into two triangles
        duplicate = self.codim != 2 and self.reference_element is not triangle
        self.p.set_array(np.tile(U, 2) if duplicate else U)
        self.p.set_clim(self.vmin, self.vmax)
        return (self.p,)
102 103


René Fritze's avatar
René Fritze committed
104
class Matplotlib1DAxes(MatplotlibAxesBase):
    """Draw data on a one-dimensional grid as `count` line plots in a given axes.

    Parameters
    ----------
    U
        Sequence of frames, forwarded to the base class constructor.
    ax, figure, grid, vmin, vmax, codim, sync_timer
        Forwarded to the base class constructor. `grid` has to be an
        `OnedGrid` and `codim` has to be 0 or 1.
    count
        Number of lines created in the axes.
    separate_plots
        Stored on the instance; not used by this class itself.
    """

    def __init__(self, U, ax, figure, grid, count=1, vmin=None, vmax=None, codim=1, separate_plots=False,
                 sync_timer=None):
        # Imported here because the module-level import of OnedGrid only happens
        # in the Qt branch at the end of this file.
        from pymor.discretizers.builtin.grids.oned import OnedGrid

        assert isinstance(grid, OnedGrid)
        assert codim in (0, 1)

        self.count = count
        self.separate_plots = separate_plots
        super().__init__(U=U, ax=ax, figure=figure, grid=grid, vmin=vmin, vmax=vmax, codim=codim,
                         sync_timer=sync_timer)

    def _update_ylim(self):
        """Set the y-limits to [vmin, vmax] padded by 10% of the range, if both bounds are known."""
        if self.vmin is None or self.vmax is None:
            return
        pad = (self.vmax - self.vmin) * 0.1
        self.ax.set_ylim(self.vmin - pad, self.vmax + pad)

    def _plot_init(self):
        """Create `count` lines with all-zero y-data over the grid's x-coordinates."""
        centers = self.grid.centers(1)
        if self.grid.identify_left_right:
            # append the right domain boundary as an additional x-coordinate
            centers = np.concatenate((centers, [[self.grid.domain[1]]]), axis=0)
            self.periodic = True
        else:
            self.periodic = False

        if self.codim == 1:
            xs = centers
        else:
            # every interior coordinate twice, so that each element gets a constant segment
            xs = np.repeat(centers, 2)[1:-1]

        lines = []
        for _ in range(self.count):
            line, = self.ax.plot(xs, np.zeros_like(xs))
            lines.append(line)
        self.lines = tuple(lines)
        self._update_ylim()

    def set(self, u, vmin=None, vmax=None):
        """Load the frame `u` into all lines, update the y-limits and return the lines."""
        self.vmin = self.vmin if vmin is None else vmin
        self.vmax = self.vmax if vmax is None else vmax

        if self.codim != 1:
            ydata = np.repeat(u, 2)
        elif self.periodic:
            # repeat the first value of *this* frame for the appended right boundary
            ydata = np.concatenate((u, [u[0]]))
        else:
            ydata = u

        for line in self.lines:
            line.set_ydata(ydata)

        self._update_ylim()
        return self.lines

151 152


Stephan Rave's avatar
Stephan Rave committed
153
if config.HAVE_QT and config.HAVE_MATPLOTLIB:
154
    from Qt.QtWidgets import QSizePolicy
Stephan Rave's avatar
Stephan Rave committed
155

156 157
    import Qt

    # Select the Agg canvas by the first character of Qt.__qt_version__;
    # anything other than '4' or '5' is not supported.
    if Qt.__qt_version__[0] not in ('4', '5'):
        raise NotImplementedError
    if Qt.__qt_version__[0] == '4':
        from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
    else:
        from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
Stephan Rave's avatar
Stephan Rave committed
163 164 165

    from matplotlib.figure import Figure

166
    from pymor.discretizers.builtin.grids.oned import OnedGrid
Stephan Rave's avatar
Stephan Rave committed
167

168
    # noinspection PyShadowingNames
Stephan Rave's avatar
Stephan Rave committed
169 170
    class Matplotlib1DWidget(FigureCanvas):
        """Figure canvas showing `count` line plots of data on a one-dimensional grid.

        With `separate_plots` every line gets its own subplot (two columns) with
        y-limits from `vmin[i]` / `vmax[i]`; otherwise all lines share one axes
        whose y-limits span `min(vmin)` to `max(vmax)`. Limits are padded by 10%
        of the respective range. `U` is not used by the constructor.
        """

        def __init__(self, U, parent, grid, count, vmin=None, vmax=None, legend=None, codim=1,
                     separate_plots=False, dpi=100):
            assert isinstance(grid, OnedGrid)
            assert codim in (0, 1)

            figure = Figure(dpi=dpi)
            if not separate_plots:
                axes = figure.gca()
            self.codim = codim

            centers = grid.centers(1)
            self.periodic = bool(grid.identify_left_right)
            if self.periodic:
                # append the right domain boundary as an additional x-coordinate
                centers = np.concatenate((centers, [[grid.domain[1]]]), axis=0)
            xs = centers if codim == 1 else np.repeat(centers, 2)[1:-1]

            rows = int(count / 2) + count % 2
            created = []
            for i in range(count):
                if separate_plots:
                    figure.add_subplot(rows, 2, i + 1)
                    axes = figure.gca()
                    pad = (vmax[i] - vmin[i]) * 0.1
                    axes.set_ylim(vmin[i] - pad, vmax[i] + pad)
                line, = axes.plot(xs, np.zeros_like(xs))
                created.append(line)
                if legend and separate_plots:
                    axes.legend([legend[i]])

            if not separate_plots:
                lower, upper = min(vmin), max(vmax)
                pad = (upper - lower) * 0.1
                axes.set_ylim(lower - pad, upper + pad)
                if legend:
                    axes.legend(legend)
            self.lines = tuple(created)

            super().__init__(figure)
            self.setParent(parent)
            self.setMinimumSize(300, 300)
            self.setSizePolicy(QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding))

        def set(self, U, ind):
            """Show entry `ind` of each element of `U` in the corresponding line and redraw."""
            for line, u in zip(self.lines, U):
                values = u[ind]
                if self.codim != 1:
                    ydata = np.repeat(values, 2)
                elif self.periodic:
                    # repeat the first value for the appended right boundary
                    ydata = np.concatenate((values, [values[0]]))
                else:
                    ydata = values
                line.set_ydata(ydata)
            self.draw()

    class MatplotlibPatchWidget(FigureCanvas):
        """Figure canvas that delegates the actual drawing to a `MatplotlibPatchAxes`.

        Parameters
        ----------
        parent
            Passed to `setParent`.
        grid
            Two-dimensional grid with triangle or square reference element.
        bounding_box, vmin, vmax, codim
            Forwarded to `MatplotlibPatchAxes`; `codim` has to be 0 or 2.
        dpi
            Resolution of the created `Figure`.
        """

        def __init__(self, parent, grid, bounding_box=None, vmin=None, vmax=None, codim=2, dpi=100):
            assert grid.reference_element in (triangle, square)
            assert grid.dim == 2
            assert codim in (0, 2)

            self.figure = Figure(dpi=dpi)
            super().__init__(self.figure)

            self.setParent(parent)
            self.setMinimumSize(300, 300)
            self.setSizePolicy(QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding))

            # Keyword arguments are required here: MatplotlibPatchAxes expects
            # (ax, figure, grid, bounding_box, U, vmin, vmax, codim, ...), so the
            # axes has to be supplied explicitly and U must not be filled positionally.
            self.patch_axes = MatplotlibPatchAxes(ax=self.figure.gca(), figure=self.figure, grid=grid,
                                                  bounding_box=bounding_box, vmin=vmin, vmax=vmax,
                                                  codim=codim)

        def set(self, U, vmin=None, vmax=None):
            """Forward the data and value range to the patch axes and redraw the canvas."""
            self.patch_axes.set(U, vmin, vmax)
            self.draw()

else:

246
    class Matplotlib1DWidget:
        """Empty placeholder defined when Qt or matplotlib is not available."""

249
    class MatplotlibPatchWidget:
        """Empty placeholder defined when Qt or matplotlib is not available."""