From 7a05ed69b6d1462a59c4c7a9a31ab38abbccb515 Mon Sep 17 00:00:00 2001 From: Douglas Raillard Date: Thu, 7 Oct 2021 18:30:52 +0100 Subject: [PATCH] exekall: Fix standalone module handling passed on CLI FIX Fix invocations such as: exekall run lisa foo --conf target_conf.yml where "foo" is a simple python file, not inside any specific package. --- tools/exekall/exekall/_utils.py | 19 ++++++++++++++++++- tools/exekall/exekall/utils.py | 7 ++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tools/exekall/exekall/_utils.py b/tools/exekall/exekall/_utils.py index 6b934c2b5..0165337be 100644 --- a/tools/exekall/exekall/_utils.py +++ b/tools/exekall/exekall/_utils.py @@ -924,6 +924,23 @@ def get_subclasses(cls): return subcls_set +def get_package(module): + """ + Find the package name of a module. If the module has no package, its own + name is taken. + """ + # We keep the search local to the packages these modules are defined in, to + # avoid getting a huge set of uninteresting callables. + package = module.__package__ + if package: + name = package.split('.', 1)[0] + else: + name = None + # Standalone modules that are not inside a package will get "" as + # package name + return name or module.__name__ + + def get_recursive_module_set(module_set, package_set, visited_module_set=None): """ Retrieve the set of all modules recursively imported from the modules in @@ -938,7 +955,7 @@ def get_recursive_module_set(module_set, package_set, visited_module_set=None): return any( # Either a submodule of one of the packages or one of the # packages themselves - module.__name__.split('.', 1)[0] == package + get_package(module) == package for package in package_set ) diff --git a/tools/exekall/exekall/utils.py b/tools/exekall/exekall/utils.py index 369961ed5..efd940d5b 100644 --- a/tools/exekall/exekall/utils.py +++ b/tools/exekall/exekall/utils.py @@ -37,11 +37,8 @@ def get_callable_set(module_set, verbose=False): :param module_set: Set of modules to scan. :type module_set: set(types.ModuleType) """ - # We keep the search local to the packages these modules are defined in, to - # avoid getting a huge set of uninteresting callables. - package_set = { - module.__package__.split('.', 1)[0] for module in module_set - } + + package_set = set(map(get_package, module_set)) callable_set = set() visited_obj_set = set() visited_module_set = set() -- GitLab