Commit 58183236 authored by Tim Keil's avatar Tim Keil

[functions] avoid nesting of products and sums (like in LincombParameterFunctional)

parent cca9096d
......@@ -52,7 +52,7 @@ class Function(ParametricObject):
"""Shorthand for :meth:`~Function.evaluate`."""
return self.evaluate(x, mu)
def __add__(self, other):
def _add_sub(self, other, sign):
if isinstance(other, Number) and other == 0:
return self
elif not isinstance(other, Function):
......@@ -61,27 +61,75 @@ class Function(ParametricObject):
if np.all(other == 0.):
return self
other = ConstantFunction(other, dim_domain=self.dim_domain)
return LincombFunction([self, other], [1., 1.])
__radd__ = __add__
if not isinstance(self, LincombFunction):
if isinstance(other, LincombFunction):
functions = (self,) + other.functions
coefficients = (1.,) + (other.coefficients if sign == 1. else tuple(-c for c in other.coefficients))
else:
functions, coefficients = (self, other), (1., sign)
elif isinstance(other, LincombFunction):
functions = self.functions + other.functions
coefficients = self.coefficients + (other.coefficients if sign == 1.
else tuple(-c for c in other.coefficients))
else:
functions, coefficients = self.functions + (other,), self.coefficients + (sign,)
def __sub__(self, other):
if isinstance(other, Function):
return LincombFunction([self, other], [1., -1.])
return LincombFunction(functions, coefficients)
def _radd_sub(self, other, sign):
if isinstance(other, Number) and other == 0:
return self
elif not isinstance(other, Function):
other = np.array(other)
assert other.shape == self.shape_range
if np.all(other == 0.):
return self
other = ConstantFunction(other, dim_domain=self.dim_domain)
# note that 'other' can never be a LincombFunction
if not isinstance(self, LincombFunction):
functions, coefficients = (other, self), (1., sign)
else:
return self + (- np.array(other))
functions = (other,) + self.functions
coefficients = (1.,) + (self.coefficients if sign == 1. else tuple(-c for c in self.coefficients))
return LincombFunction(functions, coefficients)
def __add__(self, other):
return self._add_sub(other, 1.)
def __sub__(self, other):
return self._add_sub(other, -1.)
def __radd__(self, other):
return self._radd_sub(other, 1.)
def __rsub__(self, other):
return self._radd_sub(other, -1.)
def __mul__(self, other):
if isinstance(other, (Number, ParameterFunctional)):
if not isinstance(other, (Number, ParameterFunctional, Function)):
return NotImplemented
if isinstance(other, ParameterFunctional):
return LincombFunction([self], [other])
if isinstance(other, Function):
return ProductFunction([self, other])
return NotImplemented
elif isinstance(other, Number):
other = ConstantFunction(other)
if not isinstance(self, ProductFunction):
if isinstance(other, ProductFunction):
return other.with_(functions=other.functions + [self])
else:
return ProductFunction([self, other])
elif isinstance(other, ProductFunction):
functions = self.functions + other.functions
return ProductFunction(functions)
else:
return self.with_(functions=self.functions + [other])
__rmul__ = __mul__
def __neg__(self):
return LincombFunction([self], [-1.])
return self * (-1.)
class ConstantFunction(Function):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment