#!/usr/bin/python3

import argparse
import glob
import os
import re
import shutil
import subprocess

from typing import List, Optional

from ansibleext.facade import AnsibleLibraryFacade  # type: ignore
from cassandra import cassandra_executor  # type: ignore


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the snapshot copy script.

    Args:
        argv: The argument list to parse. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, so the original zero-argument call keeps
            working.

    Returns:
        The parsed options.
    """
    parser = argparse.ArgumentParser(
        description='Take a snapshot on a source cassandra host, copy it to the local machine and load it into a '
                    'keyspace on a destination cassandra host.')
    parser.add_argument('--source-host', '-sh', type=str, required=True,
                        help='The hostname of the source cassandra.')
    parser.add_argument('--destination-host', '-dh', type=str, required=True,
                        help='The hostname of the destination cassandra.')
    # The original help text was missing the space after "inventory-devel,".
    parser.add_argument('--source-inventory', '-si', type=str, required=True,
                        help='The inventory of the source host. Possible values: inventory-testing, '
                             'inventory-devel, inventory-stable')
    parser.add_argument('--source-keyspace', '-sk', type=str, required=True,
                        help='The name of the source keyspace')
    parser.add_argument('--destination-keyspace', '-dk', type=str, required=True,
                        help='The name of the destination keyspace')
    parser.add_argument('--snapshot', '-s', type=str, required=True, help='The name of the snapshot')
    parser.add_argument('--local-path', '-lp', type=str, required=True,
                        help='The folder, where the snapshot will be stored (local)')
    parser.add_argument('--tables', '-t', type=str, nargs='*',
                        help='The requested tables. All the tables must be in the source keyspace. If no tables '
                             'are specified, the snapshot will contain the whole keyspace')
    parser.add_argument('--force-new-keyspace', action='store_true',
                        help='Force to create new keyspace. The script will fail if the requested keyspace '
                             'already exists.')
    return parser.parse_args(argv)


def get_ip_from_hostname(hostname: str, inventory: str) -> str:
    """Return the ``primary_ip`` variable of a host from the given inventory.

    Args:
        hostname: The host whose variables are looked up.
        inventory: The name of the inventory to load.

    Returns:
        The value stored under ``'primary_ip'`` in the host's variables.
    """
    facade = AnsibleLibraryFacade('infrastructure')
    loaded_inventory = facade.load_inventory(inventory)
    host_variables = loaded_inventory.get_variables(hostname)
    return host_variables['primary_ip']


def check_ssh_connection(ip: str) -> None:
    """Verify that a command can be executed on the host over ssh.

    An ``echo`` is run remotely; empty output is treated as a failed
    connection, in which case the script exits with status 1.

    Args:
        ip: The address of the host to check.
    """
    output = execute_command_via_ssh(ip, 'echo checking ssh connection to {}'.format(ip))
    if output == '':
        print('SSH connection FAILED to {}'.format(ip))
        exit(1)
    print('SSH connection SUCCEEDED to {}'.format(ip))


def check_cqlsh_connection(ip: str) -> None:
    """Verify that a cassandra executor can be opened against the host.

    Any exception raised while opening (or closing) the executor is reported
    and the script exits with status 1.

    Args:
        ip: The address of the cassandra host to check.
    """
    print('checking CQLSH connection to {}'.format(ip))
    try:
        # Only entering the context matters here; the executor itself is not used.
        with cassandra_executor.CassandraExecutor(ip):
            print('CQLSH connection SUCCEEDED to {}'.format(ip))
    except Exception as error:
        print('CQLSH connection FAILED to {}, error: {}'.format(ip, str(error)))
        exit(1)


def create_local_snapshot_directory(args: argparse.Namespace) -> None:
    """Make sure the local ``<local_path>/<snapshot>`` directory exists.

    If the directory is already there the user is asked for confirmation;
    any answer other than ``Yes`` exits the script with status 1.

    Args:
        args: The parsed options; ``local_path`` and ``snapshot`` are read.
    """
    target = os.path.join(args.local_path, args.snapshot)
    print('Creating snapshot path directory on local machine: {}'.format(target))
    if os.path.isdir(target):
        answer = input('Snapshot path directory already exists on local machine. Do you want to continue? (Yes/anything else): ')
        if answer != 'Yes':
            print('Interrupted by user')
            exit(1)
        return
    os.makedirs(target)
    print('Snapshot path directory created')


def generate_snapshot_command(args: argparse.Namespace) -> str:
    """Build the ``nodetool snapshot`` command line for the requested data.

    Without tables the whole source keyspace is named; with tables a
    comma-separated ``-kt`` list of ``<keyspace>.<table>`` entries is used.

    Args:
        args: The parsed options; ``snapshot``, ``tables`` and
            ``source_keyspace`` are read.

    Returns:
        The command as a single string.
    """
    base = 'nodetool snapshot -t {}'.format(args.snapshot)
    if args.tables:
        qualified_tables = ','.join(
            '{}.{}'.format(args.source_keyspace, table) for table in args.tables)
        command = '{} -kt {}'.format(base, qualified_tables)
    else:
        command = '{} {}'.format(base, args.source_keyspace)
    print('Snapshot command: {}'.format(command))
    return command


def execute_command_via_ssh(ip: str, command: str) -> str:
    """Run a command on a remote host as root over ssh and return its stdout.

    A non-zero exit status is reported on the console but does not raise:
    callers such as the ssh connection check rely on getting an empty string
    back when the command fails.

    Args:
        ip: The address of the remote host.
        command: The command line handed to the remote shell.

    Returns:
        The standard output of the command, decoded as UTF-8.
    """
    print('Executing command on {}: {}'.format(ip, command))
    # subprocess.run drains stdout and stderr together and waits for the child.
    # Reading only stdout from a Popen with a piped stderr can block forever
    # once the stderr pipe fills up, and never reaps the process.
    completed = subprocess.run(['ssh', 'root@{}'.format(ip), command], shell=False,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    result = completed.stdout.decode('utf-8')
    if completed.returncode != 0:
        print('Command exited with status {} on {}: {}'.format(
            completed.returncode, ip, completed.stderr.decode('utf-8', errors='replace')))
    print('Executed command result: {}'.format(result))
    return result


def get_tables(args: argparse.Namespace, keyspace) -> List[str]:
    """Return the tables to process.

    Args:
        args: The parsed options; ``tables`` is read.
        keyspace: An object whose ``tables`` attribute is a mapping keyed by
            table name.

    Returns:
        The explicitly requested tables, or every table name of the keyspace
        when none were requested.
    """
    requested = args.tables
    if requested:
        return requested
    return list(keyspace.tables.keys())


def _retarget_table_description(description: str, source_keyspace: str, destination_keyspace: str) -> str:
    """Rewrite the keyspace-qualified names of a CQL export to another keyspace."""
    # Only touch "<keyspace>." used as a qualifier (optionally double-quoted).
    # A plain str.replace also rewrote table/column names that merely contain
    # the keyspace name, and put "IF NOT EXISTS" inside non-CREATE-TABLE
    # statements (e.g. "CREATE INDEX ... ON IF NOT EXISTS ks.tbl").
    qualified_source = re.compile(
        r'(CREATE TABLE )?(?<![\w".])(?:"{0}"|{0})\.'.format(re.escape(source_keyspace)))

    def replace(match):
        prefix = 'CREATE TABLE IF NOT EXISTS ' if match.group(1) else ''
        return '{}{}.'.format(prefix, destination_keyspace)

    return qualified_source.sub(replace, description)


def get_table_descriptions_with_new_keyspace(ip: str, args: argparse.Namespace) -> List[str]:
    """Export the requested source tables as CQL aimed at the destination keyspace.

    The export of every table is taken from the cluster metadata of the source
    host; qualified references to the source keyspace are rewritten to the
    destination keyspace and ``CREATE TABLE`` becomes
    ``CREATE TABLE IF NOT EXISTS``.
    NOTE(review): this assumes the export qualifies names as
    ``<keyspace>.<name>`` — confirm against the metadata implementation in use.

    Args:
        ip: The address of the source cassandra host.
        args: The parsed options; ``source_keyspace``, ``destination_keyspace``
            and ``tables`` are read.

    Returns:
        One CQL description string per table.

    Raises:
        ValueError: If the source keyspace or one of the requested tables does
            not exist.
    """
    described_tables = []
    print('creating cassandra executor with ip {}'.format(ip))
    with cassandra_executor.CassandraExecutor(ip) as executor:
        print('created cassandra executor with ip {}'.format(ip))
        keyspaces = executor.cluster.metadata.keyspaces
        if args.source_keyspace not in keyspaces:
            raise ValueError('requested keyspace does not exist: {}'.format(args.source_keyspace))
        keyspace = keyspaces[args.source_keyspace]
        print('gathering table descriptions')
        for table in get_tables(args, keyspace):
            if table not in keyspace.tables:
                raise ValueError('requested table does not exist: {}'.format(table))
            description = keyspace.tables[table].export_as_string()
            described_tables.append(
                _retarget_table_description(description, args.source_keyspace, args.destination_keyspace))
        print('gathered table descriptions')
    print('destroyed cassandra executor with ip {}'.format(ip))
    return described_tables


def create_tables_from_description(ip: str, table_descriptions: List[str], args: argparse.Namespace) -> None:
    """Execute the table descriptions on the destination host.

    When ``--force-new-keyspace`` was given the destination keyspace is created
    first with a fixed ``NetworkTopologyStrategy`` replication setting.

    Args:
        ip: The address of the destination cassandra host.
        table_descriptions: The statements to execute, one call per entry.
        args: The parsed options; ``force_new_keyspace`` and
            ``destination_keyspace`` are read.
    """
    print('creating cassandra executor with ip {}'.format(ip))
    with cassandra_executor.CassandraExecutor(ip) as session:
        print('created cassandra executor with ip {}'.format(ip))
        if args.force_new_keyspace:
            print('Creating new keyspace')
            create_keyspace = (
                "CREATE KEYSPACE {} WITH replication = "
                "{{'class': 'NetworkTopologyStrategy', 'BPSIBRIKMAIN': 3}};"
            ).format(args.destination_keyspace)
            # 60 is the executor's second positional argument
            # (presumably a timeout in seconds — TODO confirm).
            session.execute(create_keyspace, 60)
            print('New keyspace created')
        print('executing table creation commands')
        for statement in table_descriptions:
            session.execute(statement, 60)
        print('executed table creation commands')
    print('destroyed cassandra executor with ip {}'.format(ip))


def get_requested_directories_to_copy(ip: str, args: argparse.Namespace) -> List[str]:
    """Find the remote data directories of the requested tables.

    For every requested table the directories named ``<table>`` or
    ``<table>-*`` directly below ``/mnt/db/cassandra/<source_keyspace>/`` on
    the remote host are collected.

    Args:
        ip: The address of the source host.
        args: The parsed options; ``tables`` and ``source_keyspace`` are read.

    Returns:
        The directory names relative to the keyspace directory, one entry per
        matching directory.
    """
    directories = []  # type: List[str]
    for table in args.tables or []:
        # The backslashes are doubled so that the remote side receives a
        # literal "\(", "\)" and "\n": find then prints one directory per
        # line. Depth is limited to the keyspace's direct children so nested
        # directories that happen to share the table's name are not picked up.
        command = (
            "find '/mnt/db/cassandra/{keyspace}/' -mindepth 1 -maxdepth 1 -type d "
            "\\( -name \"{table}\" -o -name \"{table}-*\" \\) -printf '%P\\n'"
        ).format(keyspace=args.source_keyspace, table=table)
        result = execute_command_via_ssh(ip, command)
        found = [line.strip() for line in result.splitlines() if line.strip()]
        if not found:
            print('No matching directory found for table {}'.format(table))
        directories.extend(found)
    if not directories:
        print("No matching directories found")
    else:
        print('directories: {}'.format(directories))
    return directories


def execute_keyspace_copy_with_rsync(ip: str, args: argparse.Namespace) -> None:
    """Copy the named snapshot of every table of the source keyspace with rsync.

    The include/exclude filters keep only ``<table>/snapshots/<snapshot>/``
    below the keyspace directory; the data lands in
    ``<local_path>/<snapshot>``.

    Args:
        ip: The address of the source host.
        args: The parsed options; ``source_keyspace``, ``snapshot`` and
            ``local_path`` are read.

    Raises:
        subprocess.CalledProcessError: If rsync exits with a non-zero status.
    """
    keyspace = args.source_keyspace
    snapshot = args.snapshot
    remote_keyspace = 'root@{}:/mnt/db/cassandra/{}'.format(ip, keyspace)
    local_target = '{}/{}'.format(args.local_path, snapshot)
    filters = [
        '--include=/{}/*/'.format(keyspace),
        '--include=/{}/*/snapshots'.format(keyspace),
        '--include=/{}/*/snapshots/{}/'.format(keyspace, snapshot),
        '--include=/{}/*/snapshots/{}/**'.format(keyspace, snapshot),
        '--exclude=/{}/**'.format(keyspace),
    ]
    subprocess.run(['rsync', '-av', '--progress'] + filters + [remote_keyspace, local_target], check=True)


def execute_tables_copy_with_rsync(ip: str, directories: List[str], args: argparse.Namespace) -> None:
    """Copy the named snapshot of the given table directories with rsync.

    For each directory the filters keep only its
    ``snapshots/<snapshot>/`` subtree; everything else below the keyspace is
    excluded. The data lands in ``<local_path>/<snapshot>``.

    Args:
        ip: The address of the source host.
        directories: Table directory names relative to the keyspace directory.
        args: The parsed options; ``source_keyspace``, ``snapshot`` and
            ``local_path`` are read.

    Raises:
        subprocess.CalledProcessError: If rsync exits with a non-zero status.
    """
    keyspace = args.source_keyspace
    snapshot = args.snapshot
    command = ['rsync', '-av', '--progress']
    for directory in directories:
        table_root = '/{}/{}'.format(keyspace, directory)
        command.extend([
            '--include={}/'.format(table_root),
            '--include={}/snapshots/'.format(table_root),
            '--include={}/snapshots/{}/'.format(table_root, snapshot),
            '--include={}/snapshots/{}/**'.format(table_root, snapshot),
        ])
    command.extend([
        '--exclude=/{}/**'.format(keyspace),
        'root@{}:/mnt/db/cassandra/{}'.format(ip, keyspace),
        '{}/{}'.format(args.local_path, snapshot),
    ])
    print(command)
    subprocess.run(command, check=True)


def copy_snapshot(ip: str, args: argparse.Namespace) -> None:
    """Copy the snapshot from the source host into the local snapshot directory.

    If only the destination keyspace directory exists locally, the snapshot is
    assumed to be copied and renamed already: the user is asked whether to
    continue and the copy is skipped on ``Yes``. Otherwise rsync is run, either
    for the whole keyspace or for the requested tables only.

    Args:
        ip: The address of the source host.
        args: The parsed options; ``source_keyspace``, ``destination_keyspace``,
            ``local_path``, ``snapshot`` and ``tables`` are read.

    Raises:
        ValueError: If distinct source and destination keyspace directories
            both exist locally.
        SystemExit: If the user declines to continue.
    """
    source_keyspace = args.source_keyspace
    destination_keyspace = args.destination_keyspace
    snapshot_path = os.path.join(args.local_path, args.snapshot)
    source_keyspace_directory = '{}/{}'.format(snapshot_path, source_keyspace)
    destination_keyspace_directory = '{}/{}'.format(snapshot_path, destination_keyspace)
    print('Copying snapshot for source keyspace {}'.format(source_keyspace))
    source_exists = os.path.isdir(source_keyspace_directory)
    destination_exists = os.path.isdir(destination_keyspace_directory)
    # With identical keyspace names both paths are the same directory; that is
    # not a conflict, the copy is simply resumed into it.
    same_directory = source_keyspace_directory == destination_keyspace_directory
    if source_exists and destination_exists and not same_directory:
        raise ValueError('Both source and destination keyspace directories are exist')
    if destination_exists and not source_exists:
        user_input = input('Looks like the snapshot has been already copied and directory renamed ({} directory exists). Do you want to continue? (Yes/anything else):'.format(destination_keyspace_directory))
        if user_input != 'Yes':
            print('Interrupted by user')
            # SystemExit is raised directly; the exit() helper is injected by
            # the site module and intended for interactive sessions.
            raise SystemExit(1)
        return
    if args.tables:
        directories = get_requested_directories_to_copy(ip, args)
        execute_tables_copy_with_rsync(ip, directories, args)
    else:
        execute_keyspace_copy_with_rsync(ip, args)


def rename_local_snapshot_folder(args: argparse.Namespace) -> None:
    """Rename the local keyspace directory from the source to the destination name.

    Nothing is renamed when both keyspace names are identical. If only the
    destination directory exists, the rename is assumed to be done already and
    the user is asked whether to continue.

    Args:
        args: The parsed options; ``source_keyspace``, ``destination_keyspace``,
            ``local_path`` and ``snapshot`` are read.

    Raises:
        ValueError: If neither directory exists, or if distinct source and
            destination directories both exist.
        SystemExit: If the user declines to continue.
    """
    source_keyspace = args.source_keyspace
    destination_keyspace = args.destination_keyspace
    snapshot_path = os.path.join(args.local_path, args.snapshot)
    source_directory = '{}/{}'.format(snapshot_path, source_keyspace)
    destination_directory = '{}/{}'.format(snapshot_path, destination_keyspace)
    print('Renaming local snapshot folder from {} to {} (Basically, changing the keyspace in the path from {} to {}. This is required to use the data with sstableloader)'.format(source_directory, destination_directory, source_keyspace, destination_keyspace))
    source_exists = os.path.isdir(source_directory)
    destination_exists = os.path.isdir(destination_directory)
    if source_directory == destination_directory:
        # Same keyspace on both sides: the directory already has its final
        # name, so this must not be reported as "both exist".
        if not source_exists:
            raise ValueError('Both source and destination directories are missing')
        print('Source and destination keyspace are the same ({}), nothing to rename'.format(source_keyspace))
        return
    if not source_exists:
        if not destination_exists:
            raise ValueError('Both source and destination directories are missing')
        user_input = input('Looks like the renaming has been already finished (source directory is missing, destination directory exists). Do you want to continue? (Yes/anything else):')
        if user_input != 'Yes':
            print('Interrupted by user')
            # SystemExit is raised directly; the exit() helper is injected by
            # the site module and intended for interactive sessions.
            raise SystemExit(1)
        return
    if destination_exists:
        raise ValueError('Both source and destination directories are exist')
    # No chdir before the rename: with a relative --local-path the paths above
    # would otherwise be resolved against the new working directory and miss.
    os.rename(source_directory, destination_directory)
    print('Local snapshot folder renamed successfully from {} to {}'.format(source_directory, destination_directory))


def move_snapshot_files(args: argparse.Namespace) -> None:
    """Move the snapshot files up into their table directories.

    Every regular file found at
    ``<local_path>/<snapshot>/<destination_keyspace>/<table>/snapshots/<snapshot>/<file>``
    is moved to ``.../<destination_keyspace>/<table>/<file>``. Files that do
    not sit exactly at that location are left alone.

    Args:
        args: The parsed options; ``destination_keyspace``, ``local_path`` and
            ``snapshot`` are read.

    Raises:
        FileNotFoundError: If the local snapshot directory does not exist.
    """
    destination_keyspace = args.destination_keyspace
    print('Moving files from {}/<table>/snapshots/{}/ to {}/<table>/ (Basically, out from the snapshots directories. This is required to use the data with sstableloader)'.format(destination_keyspace, args.snapshot, destination_keyspace))
    # Work with absolute paths instead of chdir-ing: the working directory of
    # the process stays untouched, so relative --local-path values keep working.
    snapshot_path = os.path.abspath(os.path.join(args.local_path, args.snapshot))
    if not os.path.isdir(snapshot_path):
        raise FileNotFoundError('Snapshot path directory does not exist: {}'.format(snapshot_path))
    # Match the snapshot layout explicitly rather than by path depth alone, so
    # only files of this snapshot are moved and the table name is always the
    # directory right below the keyspace.
    pattern = os.path.join(glob.escape(snapshot_path), glob.escape(destination_keyspace), '*',
                           'snapshots', glob.escape(args.snapshot), '*')
    tables = set()
    for file in glob.glob(pattern):
        if not os.path.isfile(file):
            continue
        # <table>/snapshots/<snapshot>/<file>: three levels up is the table directory.
        table_directory = os.path.dirname(os.path.dirname(os.path.dirname(file)))
        shutil.move(file, os.path.join(table_directory, os.path.basename(file)))
        tables.add(os.path.basename(table_directory))
    if not tables:
        print('There are no files to move or the files already moved.')
        return
    print('Files moved successfully for tables {}'.format(tables))


def load_snapshot_to_destination(ip: str, args: argparse.Namespace) -> None:
    """Load every local table directory into the destination host with sstableloader.

    ``sstableloader --nodes <ip> <destination_keyspace>/<table directory>`` is
    run once per directory found below the local keyspace directory, with the
    local snapshot directory as working directory of the child process.

    Args:
        ip: The address of the destination cassandra host.
        args: The parsed options; ``local_path``, ``snapshot`` and
            ``destination_keyspace`` are read.

    Raises:
        FileNotFoundError: If the local snapshot directory does not exist.
        subprocess.CalledProcessError: If sstableloader exits with a non-zero
            status.
    """
    print('Loading data with sstableloader to {}'.format(ip))
    # The child gets its working directory through ``cwd``; the working
    # directory of this process is left untouched.
    snapshot_path = os.path.abspath(os.path.join(args.local_path, args.snapshot))
    if not os.path.isdir(snapshot_path):
        raise FileNotFoundError('Snapshot path directory does not exist: {}'.format(snapshot_path))
    pattern = os.path.join(glob.escape(snapshot_path), glob.escape(args.destination_keyspace), '*')
    # Only directories are table data; stray files must not be handed to sstableloader.
    table_directories = sorted(
        os.path.relpath(path, snapshot_path) for path in glob.glob(pattern) if os.path.isdir(path))
    for directory in table_directories:
        command = ['sstableloader', '--nodes', ip, directory]
        print('Loading directory {} using {} ...'.format(directory, command))
        subprocess.run(command, cwd=snapshot_path, check=True)
        print('... done')
    print('Loading data finished successfully')


def check_described_tables_if_needed(described_tables: List[str], args: argparse.Namespace) -> None:
    """Ask for confirmation before copying a whole keyspace.

    When tables were requested explicitly nothing is asked. Otherwise the
    descriptions are printed and any answer other than ``Yes`` exits the
    script with status 1.

    Args:
        described_tables: The table descriptions to show to the user.
        args: The parsed options; ``tables`` is read.
    """
    if args.tables:
        return
    print('Found tables:')
    for description in described_tables:
        print(description)
    answer = input('Do you want to copy all the tables above? (Yes/anything else): ')
    if answer == 'Yes':
        return
    print('Interrupted by user')
    exit(1)


def main():
    """Run the whole snapshot copy workflow.

    Steps, in order: resolve both hosts, check ssh and CQL connectivity, take
    the snapshot on the source, recreate the tables on the destination, copy
    the snapshot to the local machine, reshape the local directory layout and
    load the data into the destination.
    """
    args = parse_args()

    # The destination host is always looked up in the testing inventory.
    source_ip = get_ip_from_hostname(args.source_host, args.source_inventory)
    destination_ip = get_ip_from_hostname(args.destination_host, 'inventory-testing')

    for ip in (source_ip, destination_ip):
        check_ssh_connection(ip)
    for ip in (source_ip, destination_ip):
        check_cqlsh_connection(ip)

    execute_command_via_ssh(source_ip, generate_snapshot_command(args))

    table_descriptions = get_table_descriptions_with_new_keyspace(source_ip, args)
    check_described_tables_if_needed(table_descriptions, args)
    create_tables_from_description(destination_ip, table_descriptions, args)

    create_local_snapshot_directory(args)
    copy_snapshot(source_ip, args)
    rename_local_snapshot_folder(args)
    move_snapshot_files(args)
    load_snapshot_to_destination(destination_ip, args)


# Run the workflow only when executed as a script, so that importing this
# module (e.g. from a test) does not start parsing arguments and connecting.
if __name__ == '__main__':
    main()
