Commit 524cfcbd authored by Stephan Rave's avatar Stephan Rave

[models] implement solve()/output()/error_estimate() in terms of compute()

parent c29a64d6
Pipeline #65598 failed with stages
in 45 minutes and 10 seconds
...@@ -333,9 +333,10 @@ def _compute_errors(mu, fom, reductor, error_estimator, error_norms, condition, ...@@ -333,9 +333,10 @@ def _compute_errors(mu, fom, reductor, error_estimator, error_norms, condition,
for i_N, N in enumerate(basis_sizes): for i_N, N in enumerate(basis_sizes):
rom = reductor.reduce(dims={k: N for k in reductor.bases}) rom = reductor.reduce(dims={k: N for k in reductor.bases})
u = rom.solve(mu) result = rom.compute(solution=True, solution_error_estimate=error_estimator, mu=mu)
u = result['solution']
if error_estimator: if error_estimator:
e = rom.estimate_error(u, mu) e = result['solution_error_estimate']
e = e[0] if hasattr(e, '__len__') else e e = e[0] if hasattr(e, '__len__') else e
error_estimates[i_N] = e error_estimates[i_N] = e
if fom and reductor: if fom and reductor:
......
...@@ -268,7 +268,7 @@ def _rb_surrogate_evaluate(rom=None, fom=None, reductor=None, mus=None, error_no ...@@ -268,7 +268,7 @@ def _rb_surrogate_evaluate(rom=None, fom=None, reductor=None, mus=None, error_no
return -1., None return -1., None
if fom is None: if fom is None:
errors = [rom.estimate_error(rom.solve(mu), mu) for mu in mus] errors = [rom.estimate_error(mu) for mu in mus]
elif error_norm is not None: elif error_norm is not None:
errors = [error_norm(fom.solve(mu) - reductor.reconstruct(rom.solve(mu))) for mu in mus] errors = [error_norm(fom.solve(mu) - reductor.reconstruct(rom.solve(mu))) for mu in mus]
else: else:
......
...@@ -84,18 +84,8 @@ class StationaryModel(Model): ...@@ -84,18 +84,8 @@ class StationaryModel(Model):
f' output_space: {self.output_space}\n' f' output_space: {self.output_space}\n'
) )
def _solve(self, mu=None, return_output=False): def _compute_solution(self, mu=None, **kwargs):
# explicitly checking if logging is disabled saves the str(mu) call return self.operator.apply_inverse(self.rhs.as_range_array(mu), mu=mu)
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
U = self.operator.apply_inverse(self.rhs.as_range_array(mu), mu=mu)
if return_output:
if self.output_functional is None:
raise ValueError('Model has no output')
return U, self.output_functional.apply(U, mu=mu)
else:
return U
class InstationaryModel(Model): class InstationaryModel(Model):
...@@ -195,21 +185,12 @@ class InstationaryModel(Model): ...@@ -195,21 +185,12 @@ class InstationaryModel(Model):
def with_time_stepper(self, **kwargs): def with_time_stepper(self, **kwargs):
return self.with_(time_stepper=self.time_stepper.with_(**kwargs)) return self.with_(time_stepper=self.time_stepper.with_(**kwargs))
def _solve(self, mu=None, return_output=False): def _compute_solution(self, mu=None, **kwargs):
# explicitly checking if logging is disabled saves the expensive str(mu) call
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
mu = mu.with_(t=0.) mu = mu.with_(t=0.)
U0 = self.initial_data.as_range_array(mu) U0 = self.initial_data.as_range_array(mu)
U = self.time_stepper.solve(operator=self.operator, rhs=self.rhs, initial_data=U0, mass=self.mass, U = self.time_stepper.solve(operator=self.operator, rhs=self.rhs, initial_data=U0, mass=self.mass,
initial_time=0, end_time=self.T, mu=mu, num_values=self.num_values) initial_time=0, end_time=self.T, mu=mu, num_values=self.num_values)
if return_output: return U
if self.output_functional is None:
raise ValueError('Model has no output')
return U, self.output_functional.apply(U, mu=mu)
else:
return U
def to_lti(self): def to_lti(self):
"""Convert model to |LTIModel|. """Convert model to |LTIModel|.
......
...@@ -8,6 +8,7 @@ from pymor.operators.constructions import induced_norm ...@@ -8,6 +8,7 @@ from pymor.operators.constructions import induced_norm
from pymor.parameters.base import ParametricObject, Mu from pymor.parameters.base import ParametricObject, Mu
from pymor.tools.frozendict import FrozenDict from pymor.tools.frozendict import FrozenDict
from pymor.tools.deprecated import Deprecated from pymor.tools.deprecated import Deprecated
from pymor.vectorarrays.interface import VectorArray
class Model(CacheableObject, ParametricObject): class Model(CacheableObject, ParametricObject):
...@@ -47,12 +48,95 @@ class Model(CacheableObject, ParametricObject): ...@@ -47,12 +48,95 @@ class Model(CacheableObject, ParametricObject):
self.__auto_init(locals()) self.__auto_init(locals())
@abstractmethod def _compute(self, solution=False, output=False,
def _solve(self, mu=None, return_output=False, **kwargs): solution_error_estimate=False, output_error_estimate=False,
"""Perform the actual solving.""" mu=None, **kwargs):
pass return {}
def solve(self, mu=None, return_output=False, **kwargs): def _compute_solution(self, mu=None, **kwargs):
raise NotImplementedError
def _compute_output(self, solution, mu=None, **kwargs):
if not hasattr(self, 'output_functional'):
raise NotImplementedError
if self.output_functional is None:
raise ValueError('Model has no output')
return self.output_functional.apply(solution, mu=mu)
def _compute_solution_error_estimate(self, solution, mu=None, **kwargs):
if self.error_estimator is None:
raise ValueError('Model has no error estimator')
return self.error_estimator.estimate_error(solution, mu, self)
def _compute_output_error_estimate(self, solution, mu=None, **kwargs):
if self.error_estimator is None:
raise ValueError('Model has no error estimator')
return self.error_estimator.estimate_output_error(solution, mu, self)
_compute_allowed_kwargs = frozenset()
def compute(self, solution=False, output=False,
solution_error_estimate=False, output_error_estimate=False, *,
mu=None, **kwargs):
# make sure no unknown kwargs are passed
assert kwargs.keys() <= self._compute_allowed_kwargs
# parse parameter values
if not isinstance(mu, Mu):
mu = self.parameters.parse(mu)
assert self.parameters.assert_compatible(mu)
# log output
# explicitly checking if logging is disabled saves some cpu cycles
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
# first call _compute to give subclasses more control
data = self._compute(solution=solution, output=output,
solution_error_estimate=solution_error_estimate,
output_error_estimate=output_error_estimate,
mu=mu, **kwargs)
if (solution or output or solution_error_estimate or output_error_estimate) and \
'solution' not in data:
retval = self.cached_method_call(self._compute_solution, mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'solution' in retval
data.update(retval)
else:
data['solution'] = retval
if output and 'output' not in data:
# TODO use caching here (requires skipping args in key generation)
retval = self._compute_output(data['solution'], mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'output' in retval
data.update(retval)
else:
data['output'] = retval
if solution_error_estimate and 'solution_error_estimate' not in data:
# TODO use caching here (requires skipping args in key generation)
retval = self._compute_solution_error_estimate(data['solution'], mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'solution_error_estimate' in retval
data.update(retval)
else:
data['solution_error_estimate'] = retval
if output_error_estimate and 'output_error_estimate' not in data:
# TODO use caching here (requires skipping args in key generation)
retval = self._compute_output_error_estimate(data['solution'], mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'output_error_estimate' in retval
data.update(retval)
else:
data['output_error_estimate'] = retval
return data
def solve(self, mu=None, return_error_estimate=False, **kwargs):
"""Solve the discrete problem for the |parameter values| `mu`. """Solve the discrete problem for the |parameter values| `mu`.
The result will be :mod:`cached <pymor.core.cache>` The result will be :mod:`cached <pymor.core.cache>`
...@@ -71,12 +155,18 @@ class Model(CacheableObject, ParametricObject): ...@@ -71,12 +155,18 @@ class Model(CacheableObject, ParametricObject):
The solution |VectorArray|. When `return_output` is `True`, The solution |VectorArray|. When `return_output` is `True`,
the output |VectorArray| is returned as second value. the output |VectorArray| is returned as second value.
""" """
if not isinstance(mu, Mu): data = self.compute(
mu = self.parameters.parse(mu) solution=True,
assert self.parameters.assert_compatible(mu) solution_error_estimate=return_error_estimate,
return self.cached_method_call(self._solve, mu=mu, return_output=return_output, **kwargs) mu=mu,
**kwargs
)
if return_error_estimate:
return data['solution'], data['solution_error_estimate']
else:
return data['solution']
def output(self, mu=None, **kwargs): def output(self, mu=None, return_error_estimate=False, **kwargs):
"""Return the model output for given |parameter values| `mu`. """Return the model output for given |parameter values| `mu`.
Parameters Parameters
...@@ -88,9 +178,18 @@ class Model(CacheableObject, ParametricObject): ...@@ -88,9 +178,18 @@ class Model(CacheableObject, ParametricObject):
------- -------
The computed model output as a |VectorArray| from `output_space`. The computed model output as a |VectorArray| from `output_space`.
""" """
return self.solve(mu=mu, return_output=True, **kwargs)[1] data = self.compute(
output=True,
output_error_estimate=return_error_estimate,
mu=mu,
**kwargs
)
if return_error_estimate:
return data['output'], data['output_error_estimate']
else:
return data['output']
def estimate_error(self, U, mu=None): def estimate_error(self, mu=None, **kwargs):
"""Estimate the model error for a given solution. """Estimate the model error for a given solution.
The model error could be the error w.r.t. the analytical The model error could be the error w.r.t. the analytical
...@@ -108,14 +207,39 @@ class Model(CacheableObject, ParametricObject): ...@@ -108,14 +207,39 @@ class Model(CacheableObject, ParametricObject):
------- -------
The estimated error. The estimated error.
""" """
if getattr(self, 'error_estimator') is not None: return self.compute(
return self.error_estimator.estimate_error(U, mu=mu, m=self) solution_error_estimate=True,
else: mu=mu,
raise NotImplementedError('Model has no error estimator.') **kwargs
)['solution_error_estimate']
@Deprecated('estimate_error') @Deprecated('estimate_error')
def estimate(self, U, mu=None): def estimate(self, U, mu=None):
return self.estimate_error(U, mu) return self.estimate_error(mu)
def estimate_output_error(self, mu=None, **kwargs):
"""Estimate the model error for a given solution.
The model error could be the error w.r.t. the analytical
solution of the given problem or the model reduction error w.r.t.
a corresponding high-dimensional |Model|.
Parameters
----------
U
The solution obtained by :meth:`~solve`.
mu
|Parameter values| for which `U` has been obtained.
Returns
-------
The estimated error.
"""
return self.compute(
output_error_estimate=True,
mu=mu,
**kwargs
)['output_error_estimate']
def visualize(self, U, **kwargs): def visualize(self, U, **kwargs):
"""Visualize a solution |VectorArray| U. """Visualize a solution |VectorArray| U.
......
...@@ -45,9 +45,6 @@ class InputOutputModel(Model): ...@@ -45,9 +45,6 @@ class InputOutputModel(Model):
def output_dim(self): def output_dim(self):
return self.output_space.dim return self.output_space.dim
def _solve(self, mu=None):
raise NotImplementedError
def eval_tf(self, s, mu=None): def eval_tf(self, s, mu=None):
"""Evaluate the transfer function.""" """Evaluate the transfer function."""
raise NotImplementedError raise NotImplementedError
......
...@@ -36,9 +36,9 @@ class MPIModel: ...@@ -36,9 +36,9 @@ class MPIModel:
self.parameters_internal = m.parameters_internal self.parameters_internal = m.parameters_internal
self.visualizer = MPIVisualizer(obj_id) self.visualizer = MPIVisualizer(obj_id)
def _solve(self, mu=None): def _compute_solution(self, mu=None, **kwargs):
return self.solution_space.make_array( return self.solution_space.make_array(
mpi.call(mpi.method_call_manage, self.obj_id, 'solve', mu=mu) mpi.call(mpi.method_call_manage, self.obj_id, '_compute_solution', mu=mu, **kwargs)
) )
def visualize(self, U, **kwargs): def visualize(self, U, **kwargs):
......
...@@ -62,9 +62,7 @@ if config.HAVE_TORCH: ...@@ -62,9 +62,7 @@ if config.HAVE_TORCH:
if output_functional is not None: if output_functional is not None:
self.output_space = output_functional.range self.output_space = output_functional.range
def _solve(self, mu=None, return_output=False): def _compute_solution(self, mu=None, **kwargs):
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
# convert the parameter `mu` into a form that is usable in PyTorch # convert the parameter `mu` into a form that is usable in PyTorch
converted_input = torch.from_numpy(mu.to_numpy()).double() converted_input = torch.from_numpy(mu.to_numpy()).double()
...@@ -73,13 +71,7 @@ if config.HAVE_TORCH: ...@@ -73,13 +71,7 @@ if config.HAVE_TORCH:
# convert plain numpy array to element of the actual solution space # convert plain numpy array to element of the actual solution space
U = self.solution_space.make_array(U) U = self.solution_space.make_array(U)
if return_output: return U
if self.output_functional is None:
raise ValueError('Model has no output')
return U, self.output_functional.apply(U, mu=mu)
else:
return U
class FullyConnectedNN(nn.Module, BasicObject): class FullyConnectedNN(nn.Module, BasicObject):
"""Class for neural networks with fully connected layers. """Class for neural networks with fully connected layers.
......
...@@ -72,10 +72,10 @@ def analyze_pickle_histogram(args): ...@@ -72,10 +72,10 @@ def analyze_pickle_histogram(args):
if hasattr(rom, 'estimate'): if hasattr(rom, 'estimate'):
ests = [] ests = []
for u, mu in zip(us, mus): for mu in mus:
print(f'Estimating error for {mu} ... ', end='') print(f'Estimating error for {mu} ... ', end='')
sys.stdout.flush() sys.stdout.flush()
ests.append(rom.estimate_error(u, mu=mu)) ests.append(rom.estimate_error(mu))
print('done') print('done')
if args['--detailed']: if args['--detailed']:
...@@ -212,10 +212,10 @@ def analyze_pickle_convergence(args): ...@@ -212,10 +212,10 @@ def analyze_pickle_convergence(args):
if hasattr(rom, 'estimate'): if hasattr(rom, 'estimate'):
ests = [] ests = []
start = time.time() start = time.time()
for u, mu in zip(us, mus): for mu in mus:
# print('e', end='') # print('e', end='')
# sys.stdout.flush() # sys.stdout.flush()
ests.append(rom.estimate_error(u, mu=mu)) ests.append(rom.estimate_error(mu))
ESTS.append(max(ests)) ESTS.append(max(ests))
T_ESTS.append((time.time() - start) * 1000. / len(mus)) T_ESTS.append((time.time() - start) * 1000. / len(mus))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment