diff --git a/dune/xt/common/bindings.cc b/dune/xt/common/bindings.cc index eb10bdfdfebeb1e688cbdc744d40e405c51e983b..e6820feb5fe6aa33120eec19e1230513134646ce 100644 --- a/dune/xt/common/bindings.cc +++ b/dune/xt/common/bindings.cc @@ -16,6 +16,8 @@ #include <boost/numeric/conversion/cast.hpp> +#include <dune/xt/common/exceptions.hh> + #include <dune/common/parallel/mpihelper.hh> #if HAVE_DUNE_FEM @@ -33,29 +35,28 @@ #include "fmatrix.pbh" #include "configuration.pbh" -namespace py = pybind11; -using namespace pybind11::literals; - PYBIND11_PLUGIN(_common) { + namespace py = pybind11; + using namespace pybind11::literals; + py::module m("_common", "dune-xt-common"); Dune::XT::Common::bind_Exception(m); - m.def("init_mpi", + m.def("_init_mpi", [](const std::vector<std::string>& args) { int argc = boost::numeric_cast<int>(args.size()); char** argv = Dune::XT::Common::vector_to_main_args(args); + Dune::MPIHelper::instance(argc, argv); #if HAVE_DUNE_FEM Dune::Fem::MPIManager::initialize(argc, argv); -#else - Dune::MPIHelper::instance(argc, argv); #endif }, "args"_a = std::vector<std::string>()); - m.def("init_logger", + m.def("_init_logger", [](const ssize_t max_info_level, const ssize_t max_debug_level, const bool enable_warnings, diff --git a/python/dune/xt/common/__init__.py b/python/dune/xt/common/__init__.py index 52805a53d0549db4a6ae0c0f5c79680ee8d033e9..23387f6a9002082573c6963c9a0f13384c0a55d2 100644 --- a/python/dune/xt/common/__init__.py +++ b/python/dune/xt/common/__init__.py @@ -14,6 +14,18 @@ try: except ImportError: pass +_init_logger_methods = list() +_init_mpi_methods = list() +_other_modules = ('xt.la', 'xt.grid', 'xt.functions', 'gdt') + +from ._common import __dict__ as module +to_import = [name for name in module if not name.startswith('_')] +globals().update({name: module[name] for name in to_import}) +_init_logger_methods.append(module['_init_logger']) +_init_mpi_methods.append(module['_init_mpi']) +del to_import +del module + def init_logger(max_info_level=-1, max_debug_level=-1, @@ -22,26 +34,27 @@ def init_logger(max_info_level=-1, info_color='blue', debug_color='darkgray', warning_color='red'): - from ._common import init_logger as _init_logger - initializers = [_init_logger] - for module_name in ('xt.la', 'xt.grid', 'xt.functions', 'gdt'): + init_logger_methods = _init_logger_methods.copy() + for module_name in _other_modules: try: mm = import_module('dune.{}'.format(module_name)) - initializers.append(mm.init_logger) + for init_logger_method in mm._init_logger_methods: + init_logger_methods.append(init_logger_method) except ModuleNotFoundError: pass - for initializer in initializers: - initializer(max_info_level, max_debug_level, enable_warnings, enable_colors, info_color, debug_color, - warning_color) - + for init_logger_method in init_logger_methods: + init_logger_method(max_info_level, max_debug_level, enable_warnings, enable_colors, info_color, debug_color, + warning_color) def init_mpi(args=list()): - try: - from dune.gdt import init_mpi as _init_mpi - except ModuleNotFoundError: - from dune.xt.common import init_mpi as _init_mpi - _init_mpi(args) - - -from ._common import * + init_mpi_methods = _init_mpi_methods.copy() + for module_name in _other_modules: + try: + mm = import_module('dune.{}'.format(module_name)) + for init_mpi_method in mm._init_mpi_methods: + init_mpi_methods.append(init_mpi_method) + except ModuleNotFoundError: + pass + for init_mpi_method in init_mpi_methods: + init_mpi_method(args)