Commit 8f7ae9e7 authored by Tim Keil's avatar Tim Keil

[operators] move add sub and mul of LincombOperator to the base class

parent 8b54f1db
......@@ -221,57 +221,6 @@ class LincombOperator(Operator):
def as_source_array(self, mu=None):
return self._as_array(True, mu)
def _add_sub(self, other, sign):
if not isinstance(other, Operator):
return NotImplemented
if self.name != 'LincombOperator':
if isinstance(other, LincombOperator) and other.name == 'LincombOperator':
operators = (self,) + other.operators
coefficients = (1.,) + (other.coefficients if sign == 1. else tuple(-c for c in other.coefficients))
else:
operators, coefficients = (self, other), (1., sign)
elif isinstance(other, LincombOperator) and other.name == 'LincombOperator':
operators = self.operators + other.operators
coefficients = self.coefficients + (other.coefficients if sign == 1.
else tuple(-c for c in other.coefficients))
else:
operators, coefficients = self.operators + (other,), self.coefficients + (sign,)
return LincombOperator(operators, coefficients, solver_options=self.solver_options)
def _radd_sub(self, other, sign):
if not isinstance(other, Operator):
return NotImplemented
# note that 'other' can never be a LincombOperator
if self.name != 'LincombOperator':
operators, coefficients = (other, self), (1., sign)
else:
operators = (other,) + self.operators
coefficients = (1.,) + (self.coefficients if sign == 1. else tuple(-c for c in self.coefficients))
return LincombOperator(operators, coefficients, solver_options=other.solver_options)
def __add__(self, other):
return self._add_sub(other, 1.)
def __sub__(self, other):
return self._add_sub(other, -1.)
def __radd__(self, other):
return self._radd_sub(other, 1.)
def __rsub__(self, other):
return self._radd_sub(other, -1.)
def __mul__(self, other):
assert isinstance(other, (Number, ParameterFunctional))
if self.name != 'LincombOperator':
return LincombOperator((self,), (other,))
else:
return self.with_(coefficients=tuple(c * other for c in self.coefficients))
class ConcatenationOperator(Operator):
"""|Operator| representing the concatenation of two |Operators|.
......
......@@ -520,35 +520,61 @@ class Operator(ParametricObject):
"""
raise NotImplementedError
def __add__(self, other):
"""Sum of two operators."""
if other == 0:
return self
def _add_sub(self, other, sign):
if not isinstance(other, Operator):
return NotImplemented
from pymor.operators.constructions import LincombOperator
if isinstance(other, LincombOperator):
return NotImplemented
if self.name != 'LincombOperator':
if isinstance(other, LincombOperator) and other.name == 'LincombOperator':
operators = (self,) + other.operators
coefficients = (1.,) + (other.coefficients if sign == 1. else tuple(-c for c in other.coefficients))
else:
operators, coefficients = (self, other), (1., sign)
elif isinstance(other, LincombOperator) and other.name == 'LincombOperator':
assert isinstance(self, LincombOperator)
operators = self.operators + other.operators
coefficients = self.coefficients + (other.coefficients if sign == 1.
else tuple(-c for c in other.coefficients))
else:
return LincombOperator([self, other], [1., 1.])
assert isinstance(self, LincombOperator)
operators, coefficients = self.operators + (other,), self.coefficients + (sign,)
__radd__ = __add__
return LincombOperator(operators, coefficients, solver_options=self.solver_options)
def __sub__(self, other):
def _radd_sub(self, other, sign):
if not isinstance(other, Operator):
return NotImplemented
from pymor.operators.constructions import LincombOperator
if isinstance(other, LincombOperator):
return NotImplemented
# note that 'other' can never be a LincombOperator
if self.name != 'LincombOperator':
operators, coefficients = (other, self), (1., sign)
else:
return LincombOperator([self, other], [1., -1.])
assert isinstance(self, LincombOperator)
operators = (other,) + self.operators
coefficients = (1.,) + (self.coefficients if sign == 1. else tuple(-c for c in self.coefficients))
return LincombOperator(operators, coefficients, solver_options=other.solver_options)
def __add__(self, other):
return self._add_sub(other, 1.)
def __sub__(self, other):
return self._add_sub(other, -1.)
def __radd__(self, other):
return self._radd_sub(other, 1.)
def __rsub__(self, other):
return self._radd_sub(other, -1.)
def __mul__(self, other):
"""Product of operator by a scalar."""
if not isinstance(other, (Number, ParameterFunctional)):
return NotImplemented
assert isinstance(other, (Number, ParameterFunctional))
from pymor.operators.constructions import LincombOperator
return LincombOperator([self], [other])
if self.name != 'LincombOperator':
return LincombOperator((self,), (other,))
else:
assert isinstance(self, LincombOperator)
return self.with_(coefficients=tuple(c * other for c in self.coefficients))
def __rmul__(self, other):
return self * other
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment