# -*- coding: UTF-8 -*-

import glob
import re
import itertools
import ansibleext.general_utils as utils
from ansibleext.inventory_base import InventoryBase
from typing import Dict, Any


def intersect_dicts(d1: Dict[Any, Any], d2: Dict[Any, Any]):
    """Return the entries of ``d1`` whose keys are also present in ``d2``.

    Values (and iteration order) always come from ``d1``; ``d2`` is only
    consulted for key membership.
    """
    return {key: d1[key] for key in d1 if key in d2}


class InventoryLoadable(InventoryBase):
    """Checks that every ``inventory-*`` file can be loaded."""

    def runTest(self):
        """Load each file matching ``inventory-*`` and expect a non-None result.

        The glob pattern is relative, so it is resolved against the current
        working directory. ``inventory-stable`` must be among the matches.
        """
        found_files = glob.glob('inventory-*')
        self.assertIn('inventory-stable', found_files)
        for path in found_files:
            self.assertIsNotNone(self.inventory(path))


class InventoryAnsibleVersionIsSameAsInTests(InventoryBase):
    """Checks each host's accepted ansible-core version against the inventory's Ansible version."""

    def runTest(self):
        """Assert that every host's accepted core version occurs in the inventory's version.

        For each inventory, the host variable
        ``ansible_accepted_version['core']['full']`` of every host must be
        contained in the value returned by ``inventory.get_ansible_version()``.
        """
        for inventory in self.inventories():
            reported_version = inventory.get_ansible_version()
            for host in inventory.get_hosts('*'):
                host_vars = inventory.get_variables(host.get_name())
                accepted_core = host_vars['ansible_accepted_version']['core']['full']
                # Containment instead of equality: per the original author's note,
                # ansible-core is accepted as a range for a given ansible release.
                self.assertIn(accepted_core, reported_version)


class InventoriesDoNotContainInconsistentVars(InventoryBase):
    """Checks that hosts shared between inventories carry consistent variable values."""

    def runTest(self):
        """Compare host variables across every pair of inventories.

        For each pair of inventories and each host returned by
        ``utils.intersection`` for that pair, the variables are compared after
        dropping the ignored keys; hosts listed as ignored contribute no
        variables at all, so they always compare equal.
        """
        ignored_keys = ['gpsrepo_from_env', 'group_names', 'groups', 'omit', 'inventory_file', 'inventory_dir']
        ignored_hosts = ['repo-main', 'localhost']

        def is_compared(key, host_name):
            # A variable takes part in the comparison only if neither its key
            # nor its host is on an ignore list.
            return key not in ignored_keys and host_name not in ignored_hosts

        inventories = self.inventories()
        hosts_by_inventory = {inv: inv.get_hosts('*') for inv in inventories}
        vars_by_inventory = {
            inv: self.get_hosts_vars(inv, hosts)
            for inv, hosts in hosts_by_inventory.items()
        }

        for first, second in itertools.combinations(inventories, 2):
            first_vars = vars_by_inventory[first]
            second_vars = vars_by_inventory[second]
            for host in utils.intersection(first_vars, second_vars):
                compared_first = {k: v for k, v in first_vars[host].items() if is_compared(k, host)}
                compared_second = {k: v for k, v in second_vars[host].items() if is_compared(k, host)}
                # NOTE(review): presumably yields the entries of the first dict whose
                # keys exist in both dicts but whose values differ - confirm in utils.
                mismatch_first = utils.dict_key_intersection_value_difference(compared_first, compared_second)
                # Same keys, values taken from the second inventory, so that
                # assertDictEqual can show both sides of every mismatch.
                mismatch_second = {key: compared_second[key] for key in mismatch_first}
                self.assertDictEqual(
                    mismatch_first, mismatch_second,
                    f'\nKeys differ in inventory pair ({first}, {second}) for host {host}')

    def get_hosts_vars(self, inventory, hosts):
        """Return a dict mapping each host's name to its variables in ``inventory``."""
        host_names = (host.get_name() for host in hosts)
        return {name: inventory.get_variables(name) for name in host_names}


class InventoryHasUniqueSshIps(InventoryBase):
    """Checks that different hosts do not share the same ``ansible_host`` value."""

    def runTest(self):
        """Collect host names per ``ansible_host`` value and expect no duplicates.

        Hosts belonging to any of the skipped groups are left out of the check.
        """
        skipped_groups = ('employee', 'office_hwn', 'management_all', 'testing_win_hwn',
                          'testing_win_virtual', 'ungrouped', 'embedded_linuxes')
        hosts_by_ip = {}
        for inventory, host in self.inventories_hosts():
            if not self.isHostInAnyGroups(inventory, host, skipped_groups):
                ssh_ip = self.verifyPropertyExists(inventory, host, 'ansible_host')
                utils.add_to_map_of_sets(hosts_by_ip, ssh_ip, host.get_name())
        self.assertDictEqual(
            utils.duplicated_elements(hosts_by_ip), {},
            'Duplicate ansible_host IPs detected for different hosts')


class InventoryHasUniqueHostnames(InventoryBase):
    """Checks that different hosts do not share the same ``hostname`` value."""

    def runTest(self):
        """Collect host names per ``hostname`` value and expect no duplicates.

        Hosts belonging to any of the skipped groups are left out of the check.
        """
        skipped_groups = ('employee', 'office_hwn', 'management_all', 'testing_win_hwn',
                          'testing_win_virtual', 'ungrouped', 'embedded_linuxes')
        hosts_by_hostname = {}
        for inventory, host in self.inventories_hosts():
            if not self.isHostInAnyGroups(inventory, host, skipped_groups):
                hostname = self.verifyPropertyExists(inventory, host, 'hostname')
                utils.add_to_map_of_sets(hosts_by_hostname, hostname, host.get_name())
        self.assertDictEqual(
            utils.duplicated_elements(hosts_by_hostname), {},
            'Duplicate hostnames detected for different hosts')


class InventoryHasUniquePrimaryIps(InventoryBase):
    """Checks that different hosts do not share the same ``primary_ip`` value."""

    def runTest(self):
        """Collect host names per ``primary_ip`` value and expect no duplicates.

        Unlike the sibling uniqueness checks, no host groups are skipped here.
        """
        hosts_by_primary_ip = {}
        for inventory, host in self.inventories_hosts():
            # NOTE(review): other code in this file treats a getVariable result as
            # possibly empty/None; here the value is used as a key unconditionally -
            # confirm that every host is expected to define primary_ip.
            primary_ip = self.getVariable(inventory, host, 'primary_ip')
            utils.add_to_map_of_sets(hosts_by_primary_ip, primary_ip, host.get_name())
        self.assertDictEqual(
            utils.duplicated_elements(hosts_by_primary_ip), {},
            'Duplicate primary_ip IPs detected for different hosts')


class InventoryHasCorrectIpsInVariousNetworks(InventoryBase):
    """Checks that host IPs fit the address pattern of the network groups they belong to.

    Both ``ansible_host`` and ``primary_ip`` are validated (when set) against
    the regex of every network group the host is a member of. For hosts
    yielded by ``inventories_hosts_onlystable()`` the number of network-group
    memberships is checked as well.
    """

    # Network group name -> regex that the IPs of member hosts must match.
    NETWORK_TO_IP = {
        'sibrik_net': r'^192\.168\.(221|222|231|0|19[23456789]|57)\.[0-9]+$',
        'extended_devnet': r'^192\.168\.19[23456789]\.[0-9]+$',
        'kozma_net': r'^192\.168\.20[89]\.[0-9]+$',
        'ilka_net': r'^192\.168\.211\.[0-9]+$',
        'public_net': r'^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$',
        'some_vpn_net': r'^10\.66\.[0-9]+\.[0-9]+$',
        'hetzner_hosting': r'^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$',
        'zerris_hosting': r'^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$',
    }
    # Network group name -> regex of addresses that member hosts must NOT use.
    NETWORK_TO_IP_DONTUSE = {
        'sibrik_net': r'^192\.168\.221\.17[0-9]$',
    }

    # Hosts that checkForExtendedDevnet skips when the inventory name
    # contains 'inventory-testing'.
    SKIPPABLE_TESTING_HOSTS = ['cardserver-devel1', 'cardserver-devel2']

    def network_to_ip(self):
        """Return the mapping of network group name to required IP regex."""
        return InventoryHasCorrectIpsInVariousNetworks.NETWORK_TO_IP

    def assert_ip_not_prohibited(self, ip, group, inventory, host):
        """Fail if ``ip`` matches the prohibited-address regex of ``group``.

        Groups without an entry in ``NETWORK_TO_IP_DONTUSE`` have no
        prohibited range and are accepted unconditionally. ``inventory`` and
        ``host`` are used only in the failure message.
        """
        regex = InventoryHasCorrectIpsInVariousNetworks.NETWORK_TO_IP_DONTUSE.get(group)
        if regex is None:
            return
        self.assertNotRegex(ip, regex, f'IP address {ip} inconsistent in network group {group}: {inventory}, {host}')

    def runTest(self):
        """Validate IP patterns for all hosts and membership counts for stable hosts.

        A stable host that is in ``extended_devnet`` must also be in
        ``sibrik_net`` and in no further network group; every other stable
        host must be in exactly one network group.
        """
        for inventory, host in self.inventories_hosts():
            self.checkForGroup(inventory, host)
            # The extended_devnet membership check is currently disabled:
            #self.checkForExtendedDevnet(inventory, host)
        for inventory, host in self.inventories_hosts_onlystable():
            group_memberships = self.checkForGroup(inventory, host)
            host_name = host.get_name()
            if "extended_devnet" in group_memberships:
                self.assertIn("sibrik_net", group_memberships,
                    f'Host {host_name} is a member of extended_devnet, but not a member of sibrik_net. Extended_devnet should be a child of sibrik_net.')
                self.assertEqual(len(group_memberships), 2,
                    f'Host {host_name} is a member of other network groups besides extended_devnet and sibrik_net: {group_memberships}')
            else:
                self.assertEqual(len(group_memberships), 1,
                    f'Host {host_name} is a member of multiple network groups: {group_memberships}')

    def checkForGroup(self, inventory, host):
        """Validate the host's IPs against each of its network groups.

        For every group of the host that is a key of ``network_to_ip()``,
        ``ansible_host`` and ``primary_ip`` (each only when truthy) must match
        the group's regex and must not match its prohibited range.

        Returns the list of network groups the host is a member of, in the
        order ``getGroups`` yields them.
        """
        ssh_ip = self.getVariable(inventory, host, 'ansible_host')
        primary_ip = self.getVariable(inventory, host, 'primary_ip')
        networks = self.network_to_ip()
        group_memberships = []
        for group in self.getGroups(inventory, host):
            if group not in networks:
                continue
            group_memberships.append(group)
            regex = networks[group]
            if ssh_ip:
                self.assertRegex(ssh_ip, regex,
                    f'IP address {ssh_ip} inconsistent with network group {group}: {inventory}, {host}')
                self.assert_ip_not_prohibited(ssh_ip, group, inventory, host)
            if primary_ip:
                # Report the address that was actually checked (primary_ip, not ssh_ip).
                self.assertRegex(primary_ip, regex,
                    f'IP address {primary_ip} inconsistent with network group {group}: {inventory}, {host}')
                self.assert_ip_not_prohibited(primary_ip, group, inventory, host)
        return group_memberships

    def checkForExtendedDevnet(self, inventory, host):
        """Require ``extended_devnet`` membership for hosts with an extended_devnet address.

        Skippable testing hosts are exempt. If either ``ansible_host`` or
        ``primary_ip`` matches the ``extended_devnet`` regex, the host must be
        in the ``extended_devnet`` group.
        """
        if self.checkForSkippableTestingHosts(inventory, host):
            return
        ssh_ip = self.getVariable(inventory, host, 'ansible_host')
        primary_ip = self.getVariable(inventory, host, 'primary_ip')
        extended_devnet_pattern = re.compile(self.network_to_ip()['extended_devnet'])
        if self.matchesAny(extended_devnet_pattern, [ssh_ip, primary_ip]):
            self.assertIn('extended_devnet', self.getGroups(inventory, host), f'Host should be in extended_devnet: {host}')

    def matchesAny(self, pattern, strings):
        """Return True if any truthy element of ``strings`` matches the compiled ``pattern``.

        Falsy elements (``None``, empty string) are skipped rather than matched.
        """
        return any(s and pattern.match(s) for s in strings)

    def checkForSkippableTestingHosts(self, inventory, host):
        """Return True for a skippable host in an inventory named like ``inventory-testing``."""
        if host.get_name() not in InventoryHasCorrectIpsInVariousNetworks.SKIPPABLE_TESTING_HOSTS:
            return False
        return 'inventory-testing' in inventory.get_name()


class InventoryParentChildRelationshipConsistent(InventoryBase):
    """Checks that a host's ``parent`` is a hw node that refers back to the host."""

    def runTest(self):
        """Verify the parent/child link for hosts from ``inventories_hosts_onlystable()``.

        Hosts without a parent, with parent ``beton``, or reported as manually
        set up are skipped. For the rest, the parent's ``group_names`` must
        contain ``hwn`` and one of the parent's variable values must equal the
        child's name.
        """
        for inventory, host in self.inventories_hosts_onlystable():
            parent = self.getVariable(inventory, host, 'parent')
            has_checkable_parent = parent is not None and parent != 'beton'
            if not has_checkable_parent or self.isHostManuallySetup(inventory, host):
                continue
            parent_vars = self.getVariablesForHostName(inventory, parent)
            self.assertIn(
                'hwn', parent_vars['group_names'],
                f"Host {host}'s parent {parent} is not a hw node?")
            self.assertIn(
                host.get_name(), parent_vars.values(),
                f'{host} has parent {parent}, but its parent has no value referring to {host}')


class InventoryHardwareGroupsConsistent(InventoryBase):
    """Checks that hardware nodes only appear in groups that are legal for them."""

    def runTest(self):
        """Validate the groups of every host whose group names contain ``hwn``.

        Hosts in any of the excluded groups are skipped. For the remaining
        hw nodes, each group must either be explicitly allowed or contain one
        of the substrings ``hwn``, ``lxc`` or ``openvz``.
        """
        # Loop-invariant tuples are built once instead of on every iteration.
        excluded_groups = ('office_hwn', 'gyartas_hwn', 'itrack_dispatcher', 'modemserver', 'firewall',
                           'cardserver', 'employee', 'debugvpn', 'cadvisor')
        allowed_groups = ('idata',
                          'testing', 'devel', 'stable',
                          'kozma_net', 'sibrik_net', 'ilka_net', 'extended_devnet', 'public_net', 'beton_net',
                          'bacula_all', 'bacula_backup',
                          'jenkins', 'jenkins_fedtest', 'jenkins_tablettest',
                          'hw_test_automation',
                          'promtail',
                          'nonclusterspecific',
                          'bond_detect_mode_carrier')
        node_markers = ('hwn', 'lxc', 'openvz')

        for inventory, host in self.inventories_hosts():
            if self.isHostInAnyGroups(inventory, host, excluded_groups):
                continue
            if not self.groupNamesContain(inventory, host, 'hwn'):
                continue
            for group in self.getGroups(inventory, host):
                if group in allowed_groups:
                    continue
                contains = utils.containsAtLeastOneSubstring(group, node_markers)
                self.assertTrue(contains, f'{host} is a hwn node but found it under illegal group {group}')


class GroupNameContainsAllowedCharacters(InventoryBase):
    """Checks that group names consist only of ``a-z``, ``0-9`` and ``_``."""

    def runTest(self):
        """Collect every group name that violates the naming pattern and expect none."""
        allowed_name = re.compile('^[a-z0-9_]*$')
        bad_groups = {
            name
            for name in self.inventories_groups_onlystable()
            if allowed_name.search(name) is None
        }
        bad_count = len(bad_groups)
        self.assertEqual(bad_count, 0, f'There are "{bad_count}" incorrectly named groups: {bad_groups}')


class HostsPartOfAtMostOneCluster(InventoryBase):
    """Checks that a host belongs to at most one ``<clustername>_services`` group."""

    def runTest(self):
        """Count each host's ``*_services`` groups and expect no more than one.

        The name pattern accepts names made of one run of characters without
        underscore or whitespace followed by ``_services``.
        """
        for inventory, host in self.inventories_hosts_onlystable():
            # NOTE(review): the trailing tuple is presumably a list of group names
            # to leave out of the result - confirm in getGroupsWithNameMatching.
            cluster_services_groups = self.getGroupsWithNameMatching(
                inventory, host, r"^[^_\s]+(_services)$", ('logistics_services',))
            cluster_count = len(cluster_services_groups)
            # The count and the list are matching groups, so the message says "group(s)".
            self.assertLessEqual(cluster_count, 1,
                f"{host} can't be part of multiple iTrack clusters (member of multiple <clustername>_services group). Found {cluster_count} group(s): {cluster_services_groups}")


class LxcHostHasValidPreferences(InventoryBase):
    """Checks the LXC/ZFS related variables of hosts from ``inventories_hosts_onlystable()``."""

    def runTest(self):
        """Validate ``lxc_zfs*`` variables of each host and of its parent.

        * ``lxc_zfsroot`` must not be set on any host.
        * A host in both the ``hwn`` and ``lxc`` groups must define
          ``lxc_zfs_vol`` and ``lxc_zfsroots``.
        * A host without ``lxc_zfsroot_type`` must not have a parent.
        * A host with ``lxc_zfsroot_type`` and a parent (other than the
          exempt parents) needs a parent whose ``lxc_zfsroots`` contains that
          type and maps ``generic`` to the parent's ``lxc_zfs_vol``.
        """
        exempt_parents = ('beton', 'dispdevel', 'dispatcher')
        for inventory, host in self.inventories_hosts_onlystable():
            groups = self.getGroups(inventory, host)
            lxc_zfsroot_type = self.getVariable(inventory, host, 'lxc_zfsroot_type')
            lxc_zfsroot = self.getVariable(inventory, host, 'lxc_zfsroot')
            self.assertIsNone(lxc_zfsroot, f'lxc_zfsroot is deprecated, but used in {host}')

            if all(required in groups for required in ('hwn', 'lxc')):
                own_vol = self.getVariable(inventory, host, 'lxc_zfs_vol')
                self.assertIsNotNone(own_vol, f'no lxc_zfs_vol configured for {host}')
                own_roots = self.getVariable(inventory, host, 'lxc_zfsroots')
                self.assertIsNotNone(
                    own_roots,
                    f'no lxc_zfsroots configured for parent {host}, but it is an LXC HWN')

            parent_host = self.getVariable(inventory, host, 'parent')
            if not lxc_zfsroot_type:
                # Without a root type the host must not have a (truthy) parent.
                if parent_host:
                    self.fail(f'no lxc_zfsroot_type configured for {host}, but it has parent {parent_host}')
                continue
            if parent_host is None or parent_host in exempt_parents:
                continue

            parent_roots = self.getVariableForName(inventory, parent_host, 'lxc_zfsroots')
            parent_vol = self.getVariableForName(inventory, parent_host, 'lxc_zfs_vol')
            self.assertIsNotNone(
                parent_roots,
                f'no lxc_zfsroots configured for parent {parent_host} of {host}')
            self.assertTrue(
                lxc_zfsroot_type in parent_roots,
                f'on parent {parent_host}, lxc_zfsroots must contain {lxc_zfsroot_type} key because {host} refers to it')
            self.assertDictContainsItems(
                parent_roots, {'generic': parent_vol},
                f'on parent {parent_host}, lxc_zfsroots must contain lxc_zfs_vol={parent_vol} value under generic key')
