"""Implementation of the config class, which manages the configuration of different Bittensor modules.
Example:
    import argparse
    import bittensor as bt
    parser = argparse.ArgumentParser('Miner')
    bt.Axon.add_args(parser)
    bt.Subtensor.add_args(parser)
    bt.Async_subtensor.add_args(parser)
    bt.Wallet.add_args(parser)
    bt.logging.add_args(parser)
    bt.PriorityThreadPoolExecutor.add_args(parser)
    config = bt.config(parser)
    print(config)
"""
import argparse
import os
import sys
from copy import deepcopy
from typing import Any, TypeVar, Type, Optional
import yaml
from munch import DefaultMunch
def _filter_keys(obj):
    """Filters keys from an object, excluding private and certain internal properties."""
    if isinstance(obj, dict):
        return {
            k: _filter_keys(v)
            for k, v in obj.items()
            if not k.startswith("__") and not k.startswith("_Config__is_set")
        }
    elif isinstance(obj, (Config, DefaultMunch)):
        return _filter_keys(obj.toDict())
    return obj
[docs]
class InvalidConfigFile(Exception):
    """Raised when there's an error loading the config file.""" 
[docs]
class Config(DefaultMunch):
    """Manages configuration for Bittensor modules with nested namespace support."""
    def __init__(
        self,
        parser: argparse.ArgumentParser = None,
        args: Optional[list[str]] = None,
        strict: bool = False,
        default: Any = None,
    ) -> None:
        super().__init__(default)
        self.__is_set = {}
        if parser is None:
            return
        self._add_default_arguments(parser)
        args = args or sys.argv[1:]
        self._validate_required_args(parser, args)
        config_params = self._parse_args(args, parser, strict=False)
        config_path = self._get_config_path(config_params)
        strict = strict or getattr(config_params, "strict", False)
        if config_path:
            self._load_config_file(parser, config_path)
        params = self._parse_args(args, parser, strict)
        self._build_config_tree(params)
        self._detect_set_parameters(parser, args)
    def __str__(self) -> str:
        """String representation without private keys, optimized to avoid deepcopy."""
        cleaned = _filter_keys(self.toDict())
        return "\n" + yaml.dump(cleaned, sort_keys=False, default_flow_style=False)
    def __repr__(self) -> str:
        """String representation of the Config."""
        return self.__str__()
    def _validate_required_args(
        self, parser: argparse.ArgumentParser, args: list[str]
    ) -> None:
        """Validates required arguments are present."""
        missing = self._find_missing_required_args(parser, args)
        if missing:
            raise ValueError(f"Missing required arguments: {', '.join(missing)}")
    def _find_missing_required_args(
        self, parser: argparse.ArgumentParser, args: list[str]
    ) -> list[str]:
        """Identifies missing required arguments."""
        required = {a.dest for a in parser._actions if a.required}
        provided = {a.split("=")[0].lstrip("-") for a in args if a.startswith("-")}
        return list(required - provided)
    def _get_config_path(self, params: DefaultMunch) -> Optional[str]:
        """Gets Config path from parameters."""
        return getattr(params, "config", None)
    def _load_config_file(self, parser: argparse.ArgumentParser, path: str) -> None:
        """Loads Config from YAML file."""
        try:
            with open(os.path.expanduser(path)) as f:
                config = yaml.safe_load(f)
                print(f"Loading config from: {path}")
                parser.set_defaults(**config)
        except Exception as e:
            raise InvalidConfigFile(f"Error loading config: {e}") from e
    def _build_config_tree(self, params: DefaultMunch) -> None:
        """Builds nested Config structure."""
        for key, value in params.items():
            if key in ["__is_set"]:
                continue
            current = self
            parts = key.split(".")
            for part in parts[:-1]:
                current = current.setdefault(part, Config())
            current[parts[-1]] = value
    def _detect_set_parameters(
        self, parser: argparse.ArgumentParser, args: list[str]
    ) -> None:
        """Detects which parameters were explicitly set."""
        temp_parser = self._create_non_default_parser(parser)
        detected = self._parse_args(args, temp_parser, strict=False)
        self.__is_set = DefaultMunch(**{k: True for k in detected.keys()})
    def _create_non_default_parser(
        self, original: argparse.ArgumentParser
    ) -> argparse.ArgumentParser:
        """Creates a parser that ignores default values."""
        parser = deepcopy(original)
        for action in parser._actions:
            action.default = argparse.SUPPRESS
        return parser
    @staticmethod
    def _parse_args(
        args: list[str], parser: argparse.ArgumentParser, strict: bool
    ) -> DefaultMunch:
        """Parses args with error handling."""
        try:
            if strict:
                result = parser.parse_args(args)
                return DefaultMunch.fromDict(vars(result))
            result, unknown = parser.parse_known_args(args)
            for arg in unknown:
                if arg.startswith("--") and (name := arg[2:]) in vars(result):
                    setattr(result, name, True)
            return DefaultMunch.fromDict(vars(result))
        except Exception:
            raise ValueError("Invalid arguments provided.")
    def __deepcopy__(self, memo) -> "Config":
        """Creates a deep copy that maintains Config type."""
        new_config = Config()
        memo[id(self)] = new_config
        for key, value in self.items():
            new_config[key] = deepcopy(value, memo)
        new_config.__is_set = deepcopy(self.__is_set, memo)
        return new_config
[docs]
    def merge(self, other: "Config") -> None:
        """Merges another Config into this one."""
        self.update(self._merge_dicts(self, other))
        self.__is_set.update(other.__is_set) 
    @staticmethod
    def _merge_dicts(a: DefaultMunch, b: DefaultMunch) -> DefaultMunch:
        """Recursively merges two Config objects."""
        result = deepcopy(a)
        for key, value in b.items():
            if key in result:
                if isinstance(result[key], DefaultMunch) and isinstance(
                    value, DefaultMunch
                ):
                    result[key] = Config._merge_dicts(result[key], value)
                else:
                    result[key] = deepcopy(value)
            else:
                result[key] = deepcopy(value)
        return result
[docs]
    def is_set(self, param_name: str) -> bool:
        """Checks if a parameter was explicitly set."""
        return self.__is_set.get(param_name, False) 
[docs]
    def to_dict(self) -> dict:
        """Returns the configuration as a dictionary."""
        return self.toDict() 
    def _add_default_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Adds default arguments to the Config parser."""
        arguments = [
            (
                "--config",
                {
                    "type": str,
                    "help": "If set, defaults are overridden by passed file.",
                    "default": False,
                },
            ),
            (
                "--strict",
                {
                    "action": "store_true",
                    "help": "If flagged, config will check that only exact arguments have been set.",
                    "default": False,
                },
            ),
            (
                "--no_version_checking",
                {
                    "action": "store_true",
                    "help": "Set `true to stop cli version checking.",
                    "default": False,
                },
            ),
        ]
        for arg_name, kwargs in arguments:
            try:
                parser.add_argument(arg_name, **kwargs)
            except argparse.ArgumentError:
                # this can fail if argument has already been added.
                pass 
T = TypeVar("T", bound="DefaultConfig")
[docs]
class DefaultConfig(Config):
    """A Config with a set of default values."""
[docs]
    @classmethod
    def default(cls: Type[T]) -> T:
        """Get default config."""
        raise NotImplementedError("Function default is not implemented.")