Unverified Commit fbb3ff6c authored by Stephan Rave's avatar Stephan Rave Committed by GitHub
Browse files

Merge pull request #1731 from pymor/expressions_ufl_fix

Fix Expression.to_fenics for conditionals
parents f3576303 fc473ed0
Pipeline #156368 passed with stages
in 25 minutes and 55 seconds
......@@ -424,7 +424,10 @@ class BinaryOp(Expression):
if not _broadcastable_shapes(first.shape, second.shape):
raise ValueError(f'Incompatible shapes of expressions "{first}" and "{second}" with shapes '
f'{first.shape} and {second.shape} for binary operator {self.numpy_symbol}')
return np.vectorize(ufl_op)(first, second)
if self.fenics_conditional:
return np.vectorize(lambda x, y: ufl.conditional(ufl_op(x, y), 1., 0.))(first, second)
else:
return np.vectorize(ufl_op)(first, second)
def __str__(self):
return f'({self.first} {self.numpy_symbol} {self.second})'
......@@ -562,18 +565,18 @@ def _broadcastable_shapes(first, second):
return all(f == s or f == 1 or s == 1 for f, s in zip(first[::-1], second[::-1]))
class Sum(BinaryOp): numpy_symbol = '+'; fenics_symbol = operator.add # NOQA
class Diff(BinaryOp): numpy_symbol = '-'; fenics_symbol = operator.sub # NOQA
class Prod(BinaryOp): numpy_symbol = '*'; fenics_symbol = operator.mul # NOQA
class Div(BinaryOp): numpy_symbol = '/'; fenics_symbol = operator.truediv # NOQA
class Sum(BinaryOp): numpy_symbol = '+'; fenics_symbol = operator.add; fenics_conditional=False # NOQA
class Diff(BinaryOp): numpy_symbol = '-'; fenics_symbol = operator.sub; fenics_conditional=False # NOQA
class Prod(BinaryOp): numpy_symbol = '*'; fenics_symbol = operator.mul; fenics_conditional=False # NOQA
class Div(BinaryOp): numpy_symbol = '/'; fenics_symbol = operator.truediv; fenics_conditional=False # NOQA
class Pow(BinaryOp): numpy_symbol = '**'; fenics_symbol = 'elem_pow' # NOQA
class LE(BinaryOp): numpy_symbol = '<='; fenics_symbol = 'le' # NOQA
class GE(BinaryOp): numpy_symbol = '>='; fenics_symbol = 'ge' # NOQA
class LT(BinaryOp): numpy_symbol = '<'; fenics_symbol = 'lt' # NOQA
class GT(BinaryOp): numpy_symbol = '>'; fenics_symbol = 'gt' # NOQA
class Mod(BinaryOp): numpy_symbol = '%'; fenics_symbol = None # NOQA
class Pow(BinaryOp): numpy_symbol = '**'; fenics_symbol = 'elem_pow'; fenics_conditional=False # NOQA
class LE(BinaryOp): numpy_symbol = '<='; fenics_symbol = 'le'; fenics_conditional=True # NOQA
class GE(BinaryOp): numpy_symbol = '>='; fenics_symbol = 'ge'; fenics_conditional=True # NOQA
class LT(BinaryOp): numpy_symbol = '<'; fenics_symbol = 'lt'; fenics_conditional=True # NOQA
class GT(BinaryOp): numpy_symbol = '>'; fenics_symbol = 'gt'; fenics_conditional=True # NOQA
class Mod(BinaryOp): numpy_symbol = '%'; fenics_symbol = None; fenics_conditional=None # NOQA
class sin(UnaryFunctionCall): numpy_symbol = 'sin'; fenics_symbol = 'sin' # NOQA
......
......@@ -24,43 +24,6 @@ from pymor.vectorarrays.list import CopyOnWriteVector, ComplexifiedVector, Compl
from pymor.vectorarrays.numpy import NumpyVectorSpace
@defaults('doit')
def patch_ufl(doit=True):
"""Monkey patch ufl.algorithms.estimate_total_polynomial_degree.
Catches `TypeError`, which can be called by certain UFL expressions, and returns
`default_degree`.
This is needed, for instance, when using :mod:`pymor.discretizers.fenics` on a
:func:`~pymor.analyticalproblems.thermalblock.thermal_block_problem`.
"""
if not doit:
return
import ufl
real_estimate_total_polynomial_degree = ufl.algorithms.estimate_total_polynomial_degree
def estimate_total_polynomial_degree_wrapper(e, default_degree=1, element_replace_map={}):
try:
return real_estimate_total_polynomial_degree(e, default_degree=default_degree,
element_replace_map=element_replace_map)
except TypeError:
return default_degree
ufl.algorithms.estimate_degrees.estimate_total_polynomial_degree = estimate_total_polynomial_degree_wrapper
ufl.algorithms.estimate_total_polynomial_degree = estimate_total_polynomial_degree_wrapper
# use sys.modules for monkey patching since compute_form_data is at the same time function
# and sub-module
import sys
sys.modules['ufl.algorithms.compute_form_data'].estimate_total_polynomial_degree \
= estimate_total_polynomial_degree_wrapper
patch_ufl()
@unpicklable
class FenicsVector(CopyOnWriteVector):
"""Wraps a FEniCS vector to make it usable with ListVectorArray."""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment