Commit 524cfcbd authored by Stephan Rave

[models] implement solve()/output()/estimate_error() in terms of compute()

parent c29a64d6
Pipeline #65598 failed with stages in 45 minutes and 10 seconds
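The call pattern introduced by this commit can be summarized in a short, hedged sketch; `rom` and `mu` below are placeholders for any |Model| that defines an output and an error estimator, together with compatible |parameter values|:

data = rom.compute(solution=True, output=True,
                   solution_error_estimate=True, mu=mu)
u = data['solution']                   # solution |VectorArray|
out = data['output']                   # output_functional applied to u
est = data['solution_error_estimate']  # estimated solution error

# solve()/output()/estimate_error() are now thin wrappers around compute():
u, est = rom.solve(mu, return_error_estimate=True)
out = rom.output(mu)
est = rom.estimate_error(mu)           # no longer takes the solution U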
......@@ -333,9 +333,10 @@ def _compute_errors(mu, fom, reductor, error_estimator, error_norms, condition,
for i_N, N in enumerate(basis_sizes):
rom = reductor.reduce(dims={k: N for k in reductor.bases})
u = rom.solve(mu)
result = rom.compute(solution=True, solution_error_estimate=error_estimator, mu=mu)
u = result['solution']
if error_estimator:
e = rom.estimate_error(u, mu)
e = result['solution_error_estimate']
e = e[0] if hasattr(e, '__len__') else e
error_estimates[i_N] = e
if fom and reductor:
......
......@@ -268,7 +268,7 @@ def _rb_surrogate_evaluate(rom=None, fom=None, reductor=None, mus=None, error_no
return -1., None
if fom is None:
errors = [rom.estimate_error(rom.solve(mu), mu) for mu in mus]
errors = [rom.estimate_error(mu) for mu in mus]
elif error_norm is not None:
errors = [error_norm(fom.solve(mu) - reductor.reconstruct(rom.solve(mu))) for mu in mus]
else:
......
......@@ -84,18 +84,8 @@ class StationaryModel(Model):
f' output_space: {self.output_space}\n'
)
def _solve(self, mu=None, return_output=False):
# explicitly checking if logging is disabled saves the str(mu) call
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
U = self.operator.apply_inverse(self.rhs.as_range_array(mu), mu=mu)
if return_output:
if self.output_functional is None:
raise ValueError('Model has no output')
return U, self.output_functional.apply(U, mu=mu)
else:
return U
def _compute_solution(self, mu=None, **kwargs):
return self.operator.apply_inverse(self.rhs.as_range_array(mu), mu=mu)
class InstationaryModel(Model):
......@@ -195,21 +185,12 @@ class InstationaryModel(Model):
def with_time_stepper(self, **kwargs):
return self.with_(time_stepper=self.time_stepper.with_(**kwargs))
def _solve(self, mu=None, return_output=False):
# explicitly checking if logging is disabled saves the expensive str(mu) call
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
def _compute_solution(self, mu=None, **kwargs):
mu = mu.with_(t=0.)
U0 = self.initial_data.as_range_array(mu)
U = self.time_stepper.solve(operator=self.operator, rhs=self.rhs, initial_data=U0, mass=self.mass,
initial_time=0, end_time=self.T, mu=mu, num_values=self.num_values)
if return_output:
if self.output_functional is None:
raise ValueError('Model has no output')
return U, self.output_functional.apply(U, mu=mu)
else:
return U
return U
def to_lti(self):
"""Convert model to |LTIModel|.
......
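The public solving API on top of the new `_compute_solution` hook is unchanged; a minimal sanity-check sketch, assuming a NumPy-backed StationaryModel with arbitrarily chosen data:

import numpy as np
from pymor.models.basic import StationaryModel
from pymor.operators.numpy import NumpyMatrixOperator

op = NumpyMatrixOperator(np.eye(4))         # system operator
rhs = NumpyMatrixOperator(np.ones((4, 1)))  # vector-like right-hand side
m = StationaryModel(op, rhs)
U = m.solve()                               # dispatches through compute()
assert U in m.solution_space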
......@@ -8,6 +8,7 @@ from pymor.operators.constructions import induced_norm
from pymor.parameters.base import ParametricObject, Mu
from pymor.tools.frozendict import FrozenDict
from pymor.tools.deprecated import Deprecated
from pymor.vectorarrays.interface import VectorArray
class Model(CacheableObject, ParametricObject):
......@@ -47,12 +48,95 @@ class Model(CacheableObject, ParametricObject):
self.__auto_init(locals())
@abstractmethod
def _solve(self, mu=None, return_output=False, **kwargs):
"""Perform the actual solving."""
pass
def _compute(self, solution=False, output=False,
solution_error_estimate=False, output_error_estimate=False,
mu=None, **kwargs):
return {}
def solve(self, mu=None, return_output=False, **kwargs):
def _compute_solution(self, mu=None, **kwargs):
raise NotImplementedError
def _compute_output(self, solution, mu=None, **kwargs):
if not hasattr(self, 'output_functional'):
raise NotImplementedError
if self.output_functional is None:
raise ValueError('Model has no output')
return self.output_functional.apply(solution, mu=mu)
def _compute_solution_error_estimate(self, solution, mu=None, **kwargs):
if self.error_estimator is None:
raise ValueError('Model has no error estimator')
return self.error_estimator.estimate_error(solution, mu, self)
def _compute_output_error_estimate(self, solution, mu=None, **kwargs):
if self.error_estimator is None:
raise ValueError('Model has no error estimator')
return self.error_estimator.estimate_output_error(solution, mu, self)
_compute_allowed_kwargs = frozenset()
def compute(self, solution=False, output=False,
solution_error_estimate=False, output_error_estimate=False, *,
mu=None, **kwargs):
# make sure no unknown kwargs are passed
assert kwargs.keys() <= self._compute_allowed_kwargs
# parse parameter values
if not isinstance(mu, Mu):
mu = self.parameters.parse(mu)
assert self.parameters.assert_compatible(mu)
# log output
# explicitly checking if logging is disabled saves some cpu cycles
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
# first call _compute to give subclasses more control
data = self._compute(solution=solution, output=output,
solution_error_estimate=solution_error_estimate,
output_error_estimate=output_error_estimate,
mu=mu, **kwargs)
if (solution or output or solution_error_estimate or output_error_estimate) and \
'solution' not in data:
retval = self.cached_method_call(self._compute_solution, mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'solution' in retval
data.update(retval)
else:
data['solution'] = retval
if output and 'output' not in data:
# TODO use caching here (requires skipping args in key generation)
retval = self._compute_output(data['solution'], mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'output' in retval
data.update(retval)
else:
data['output'] = retval
if solution_error_estimate and 'solution_error_estimate' not in data:
# TODO use caching here (requires skipping args in key generation)
retval = self._compute_solution_error_estimate(data['solution'], mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'solution_error_estimate' in retval
data.update(retval)
else:
data['solution_error_estimate'] = retval
if output_error_estimate and 'output_error_estimate' not in data:
# TODO use caching here (requires skipping args in key generation)
retval = self._compute_output_error_estimate(data['solution'], mu=mu, **kwargs)
if isinstance(retval, dict):
assert 'output_error_estimate' in retval
data.update(retval)
else:
data['output_error_estimate'] = retval
return data
def solve(self, mu=None, return_error_estimate=False, **kwargs):
"""Solve the discrete problem for the |parameter values| `mu`.
The result will be :mod:`cached <pymor.core.cache>`
......@@ -71,12 +155,18 @@ class Model(CacheableObject, ParametricObject):
The solution |VectorArray|. When `return_error_estimate` is `True`,
the estimated error is returned as second value.
"""
if not isinstance(mu, Mu):
mu = self.parameters.parse(mu)
assert self.parameters.assert_compatible(mu)
return self.cached_method_call(self._solve, mu=mu, return_output=return_output, **kwargs)
data = self.compute(
solution=True,
solution_error_estimate=return_error_estimate,
mu=mu,
**kwargs
)
if return_error_estimate:
return data['solution'], data['solution_error_estimate']
else:
return data['solution']
def output(self, mu=None, **kwargs):
def output(self, mu=None, return_error_estimate=False, **kwargs):
"""Return the model output for given |parameter values| `mu`.
Parameters
......@@ -88,9 +178,18 @@ class Model(CacheableObject, ParametricObject):
-------
The computed model output as a |VectorArray| from `output_space`.
"""
return self.solve(mu=mu, return_output=True, **kwargs)[1]
data = self.compute(
output=True,
output_error_estimate=return_error_estimate,
mu=mu,
**kwargs
)
if return_error_estimate:
return data['output'], data['output_error_estimate']
else:
return data['output']
def estimate_error(self, U, mu=None):
def estimate_error(self, mu=None, **kwargs):
"""Estimate the model error for a given solution.
The model error could be the error w.r.t. the analytical
......@@ -108,14 +207,39 @@ class Model(CacheableObject, ParametricObject):
-------
The estimated error.
"""
if getattr(self, 'error_estimator') is not None:
return self.error_estimator.estimate_error(U, mu=mu, m=self)
else:
raise NotImplementedError('Model has no error estimator.')
return self.compute(
solution_error_estimate=True,
mu=mu,
**kwargs
)['solution_error_estimate']
@Deprecated('estimate_error')
def estimate(self, U, mu=None):
return self.estimate_error(U, mu)
return self.estimate_error(mu)
def estimate_output_error(self, mu=None, **kwargs):
"""Estimate the model error for a given solution.
The model error could be the error w.r.t. the analytical
solution of the given problem or the model reduction error w.r.t.
a corresponding high-dimensional |Model|.
Parameters
----------
U
The solution obtained by :meth:`~solve`.
mu
|Parameter values| for which `U` has been obtained.
Returns
-------
The estimated error.
"""
return self.compute(
output_error_estimate=True,
mu=mu,
**kwargs
)['output_error_estimate']
def visualize(self, U, **kwargs):
"""Visualize a solution |VectorArray| U.
......
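Since compute() merges dict return values from the `_compute_*` hooks into its result, a subclass can provide several quantities in a single pass. A hedged sketch of this mechanism (the class and the elided solver call are hypothetical):

class MyModel(Model):

    def _compute_solution(self, mu=None, **kwargs):
        U = ...  # solve the high-dimensional problem for mu (not shown)
        # returning a dict instead of a plain VectorArray lets the model
        # hand back the output together with the solution; compute() then
        # skips the separate _compute_output call
        return {'solution': U,
                'output': self.output_functional.apply(U, mu=mu)}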
......@@ -45,9 +45,6 @@ class InputOutputModel(Model):
def output_dim(self):
return self.output_space.dim
def _solve(self, mu=None):
raise NotImplementedError
def eval_tf(self, s, mu=None):
"""Evaluate the transfer function."""
raise NotImplementedError
......
......@@ -36,9 +36,9 @@ class MPIModel:
self.parameters_internal = m.parameters_internal
self.visualizer = MPIVisualizer(obj_id)
def _solve(self, mu=None):
def _compute_solution(self, mu=None, **kwargs):
return self.solution_space.make_array(
mpi.call(mpi.method_call_manage, self.obj_id, 'solve', mu=mu)
mpi.call(mpi.method_call_manage, self.obj_id, '_compute_solution', mu=mu, **kwargs)
)
def visualize(self, U, **kwargs):
......
......@@ -62,9 +62,7 @@ if config.HAVE_TORCH:
if output_functional is not None:
self.output_space = output_functional.range
def _solve(self, mu=None, return_output=False):
if not self.logging_disabled:
self.logger.info(f'Solving {self.name} for {mu} ...')
def _compute_solution(self, mu=None, **kwargs):
# convert the parameter `mu` into a form that is usable in PyTorch
converted_input = torch.from_numpy(mu.to_numpy()).double()
......@@ -73,13 +71,7 @@ if config.HAVE_TORCH:
# convert plain numpy array to element of the actual solution space
U = self.solution_space.make_array(U)
if return_output:
if self.output_functional is None:
raise ValueError('Model has no output')
return U, self.output_functional.apply(U, mu=mu)
else:
return U
return U
class FullyConnectedNN(nn.Module, BasicObject):
"""Class for neural networks with fully connected layers.
......
......@@ -72,10 +72,10 @@ def analyze_pickle_histogram(args):
if hasattr(rom, 'estimate'):
ests = []
for u, mu in zip(us, mus):
for mu in mus:
print(f'Estimating error for {mu} ... ', end='')
sys.stdout.flush()
ests.append(rom.estimate_error(u, mu=mu))
ests.append(rom.estimate_error(mu))
print('done')
if args['--detailed']:
......@@ -212,10 +212,10 @@ def analyze_pickle_convergence(args):
if hasattr(rom, 'estimate'):
ests = []
start = time.time()
for u, mu in zip(us, mus):
for mu in mus:
# print('e', end='')
# sys.stdout.flush()
ests.append(rom.estimate_error(u, mu=mu))
ests.append(rom.estimate_error(mu))
ESTS.append(max(ests))
T_ESTS.append((time.time() - start) * 1000. / len(mus))
......