diff --git a/ml_collections/config_flags/config_flags.py b/ml_collections/config_flags/config_flags.py index ccf1394..e17a349 100644 --- a/ml_collections/config_flags/config_flags.py +++ b/ml_collections/config_flags/config_flags.py @@ -20,11 +20,12 @@ import enum import errno import functools as ft -import imp +import importlib.machinery import os import re import sys import traceback +import types from typing import Any, Callable, Dict, Generic, List, MutableMapping, Optional, Sequence, Tuple, Type, TypeVar from absl import flags @@ -32,6 +33,7 @@ from ml_collections import config_dict from ml_collections.config_flags import config_path from ml_collections.config_flags import tuple_parser + FLAGS = flags.FLAGS # Forward for backwards compatibility. @@ -43,6 +45,20 @@ flags._helpers.disclaim_module_ids.add(id(sys.modules[__name__])) # pylint: disable=protected-access +def _load_source(module_name: str, module_path: str) -> types.ModuleType: + """Loads a Python module from its source file. + + Args: + module_name: name of the module in sys.modules. + module_path: path to the Python file containing the module. + + Returns: + The loaded Python module. + """ + loader = importlib.machinery.SourceFileLoader(module_name, module_path) + return loader.load_module() + + class _LiteralParser(flags.ArgumentParser): """Parse arbitrary built-in (`--cfg.val=1`, `--cfg.val="[1, 2, {}]"`,...).""" @@ -560,7 +576,7 @@ def _LoadConfigModule(name: str, path: str): # Works for relative paths. with ignoring_errors.Attempt('Relative path', path): - config_module = imp.load_source(name, path) + config_module = _load_source(name, path) return config_module # Nothing worked. Log the paths that were attempted.