Commit cca9096d authored by Tim Keil's avatar Tim Keil

[parameters.functionals] avoid nesting of products and sums similar to LincombOperator

parent a107b830
......@@ -45,39 +45,73 @@ class ParameterFunctional(ParametricObject):
def __call__(self, mu=None):
return self.evaluate(mu)
def __add__(self, other):
def _add_sub(self, other, sign):
if not isinstance(other, (ParameterFunctional, Number)):
return NotImplemented
if isinstance(other, Number):
if other == 0:
return self
other = ConstantParameterFunctional(other)
if isinstance(other, ParameterFunctional):
if isinstance(self, LincombParameterFunctional):
return self.with_(functionals=self.functionals + (other,),
coefficients=self.coefficients + (1,))
if not isinstance(self, LincombParameterFunctional):
if isinstance(other, LincombParameterFunctional):
functionals = (self,) + other.functionals
coefficients = (1.,) + (other.coefficients if sign == 1. else tuple(-c for c in other.coefficients))
else:
return LincombParameterFunctional([self, other], [1., 1.])
functionals, coefficients = (self, other), (1., sign)
elif isinstance(other, LincombParameterFunctional):
functionals = self.functionals + other.functionals
coefficients = self.coefficients + (other.coefficients if sign == 1.
else tuple(-c for c in other.coefficients))
else:
return NotImplemented
functionals, coefficients = self.functionals + (other,), self.coefficients + (sign,)
__radd__ = __add__
return LincombParameterFunctional(functionals, coefficients)
def __sub__(self, other):
if isinstance(other, ParameterFunctional):
if isinstance(self, LincombParameterFunctional):
return self.with_(functionals=self.functionals + (other,),
coefficients=self.coefficients + (-1,))
else:
return LincombParameterFunctional([self, other], [1., -1.])
def _radd_sub(self, other, sign):
if not isinstance(other, (ParameterFunctional, Number)):
return NotImplemented
if isinstance(other, Number):
if other == 0:
return self
other = ConstantParameterFunctional(other)
# note that 'other' can never be a LincombParameterFunctional
if not isinstance(self, LincombParameterFunctional):
functionals, coefficients = (other, self), (1., sign)
else:
return self + (- other)
functionals = (other,) + self.functionals
coefficients = (1.,) + (self.coefficients if sign == 1. else tuple(-c for c in self.coefficients))
return LincombParameterFunctional(functionals, coefficients)
def __add__(self, other):
return self._add_sub(other, 1.)
def __sub__(self, other):
return self._add_sub(other, -1.)
def __radd__(self, other):
return self._radd_sub(other, 1.)
def __rsub__(self, other):
return self._radd_sub(other, -1.)
def __mul__(self, other):
if not isinstance(other, (Number, ParameterFunctional)):
return NotImplemented
if isinstance(self, ProductParameterFunctional):
return self.with_(factors=self.factors + [other])
if not isinstance(self, ProductParameterFunctional):
if isinstance(other, ProductParameterFunctional):
return other.with_(factors=other.factors + [self])
else:
return ProductParameterFunctional([self, other])
elif isinstance(other, ProductParameterFunctional):
factors = self.factors + other.factors
return ProductParameterFunctional(factors)
else:
return ProductParameterFunctional([self, other])
return self.with_(factors=self.factors + [other])
__rmul__ = __mul__
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment