diff --git a/python/dune/xt/__init__.py b/python/dune/xt/__init__.py index 4fa4f424925bf6081280389712470c7125388aca..dafa999ab48b6a2e062ea9a0575acf88d36072e7 100644 --- a/python/dune/xt/__init__.py +++ b/python/dune/xt/__init__.py @@ -33,6 +33,19 @@ _init_logger_calls = set() _test_logger_calls = set() +def _register_special_funcs(mod, base_name=''): + if isinstance(mod, dict): + mod_dict = mod + else: + mod_dict = import_module('.{}'.format(mod), base_name).__dict__ + if '_init_mpi' in mod_dict: + _init_mpi_calls.add(mod_dict['_init_mpi']) + if '_init_logger' in mod_dict: + _init_logger_calls.add(mod_dict['_init_logger']) + if '_test_logger' in mod_dict: + _test_logger_calls.add(mod_dict['_test_logger']) + + def guarded_import(globs, base_name, mod_name): # see https://stackoverflow.com/questions/43059267/how-to-do-from-module-import-using-importlib try: @@ -42,12 +55,7 @@ def guarded_import(globs, base_name, mod_name): else: names = [x for x in mod.__dict__ if not x.startswith("_")] # import special init functions which should be present in every module - if '_init_mpi' in mod.__dict__: - _init_mpi_calls.add(mod.__dict__['_init_mpi']) - if '_init_logger' in mod.__dict__: - _init_logger_calls.add(mod.__dict__['_init_logger']) - if '_test_logger' in mod.__dict__: - _test_logger_calls.add(mod.__dict__['_test_logger']) + _register_special_funcs(mod.__dict__) # check the rest for duplicity for nm in names: if nm in globs: