Unverified Commit 06b0dc4c authored by René Fritze's avatar René Fritze Committed by GitHub
Browse files

Merge pull request #1650 from pymor/issue_1649

[models.nn] fix output_functional not properly assigned
parents 19d044b4 c0bb6b06
Pipeline #146460 passed with stages
in 36 minutes and 16 seconds
......@@ -95,9 +95,9 @@ class NeuralNetworkModel(BaseNeuralNetworkModel):
self.__auto_init(locals())
self.solution_space = NumpyVectorSpace(neural_network.output_dimension)
output_functional = output_functional or ZeroOperator(NumpyVectorSpace(0), self.solution_space)
assert output_functional.source == self.solution_space
self.dim_output = output_functional.range.dim
self.output_functional = output_functional or ZeroOperator(NumpyVectorSpace(0), self.solution_space)
assert self.output_functional.source == self.solution_space
self.dim_output = self.output_functional.range.dim
def _compute_solution(self, mu=None, **kwargs):
......
......@@ -2,24 +2,13 @@
# Copyright pyMOR developers and contributors. All rights reserved.
# License: BSD 2-Clause License (https://opensource.org/licenses/BSD-2-Clause)
import os
import pytest
import numpy as np
def _skip_if_no_torch():
try:
import torch # NOQA
except ImportError as ie:
if not os.environ.get('DOCKER_PYMOR', False):
pytest.skip('skipped test due to missing Torch')
raise ie
from pymortests.base import skip_if_missing
@skip_if_missing('TORCH')
def test_linear_function_fitting():
_skip_if_no_torch()
from pymor.reductors.neural_network import multiple_restarts_training
from pymor.models.neural_network import FullyConnectedNN
......@@ -67,8 +56,8 @@ def test_linear_function_fitting():
assert all(loss < tol for loss in best_losses.values())
@skip_if_missing('TORCH')
def test_no_training_data():
_skip_if_no_torch()
from pymor.reductors.neural_network import multiple_restarts_training
from pymor.models.neural_network import FullyConnectedNN
......@@ -83,3 +72,13 @@ def test_no_training_data():
neural_network = FullyConnectedNN([d_in, 3 * (d_in + d_out), 3 * (d_in + d_out), d_out]).double()
best_neural_network, _ = multiple_restarts_training(training_data, validation_data, neural_network)
assert np.allclose(best_neural_network(torch.DoubleTensor(np.random.rand(n, d_in))).detach(), np.zeros(d_out))
@skip_if_missing('TORCH')
def test_issue_1649():
from pymor.models.neural_network import FullyConnectedNN
from pymor.models.neural_network import NeuralNetworkModel
from pymor.parameters.base import Parameters
neural_network = FullyConnectedNN([3, 3, 3, 3]).double()
nn_model = NeuralNetworkModel(neural_network, Parameters(mu=1))
assert nn_model.output_functional
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment