# -*- coding: UTF-8 -*-

import json
import unittest
import contextlib
import tempfile
import os
from unittest.mock import patch
from ansible.module_utils import basic  # type: ignore
from ansible.module_utils.common.text.converters import to_bytes  # type: ignore


class AnsibleExitJson(Exception):
    """Raised by the patched ``exit_json`` in place of a normal module exit.

    The single positional argument is the result dict the module tried to
    return, so a test can inspect it via ``exc.args[0]``.
    """


class AnsibleFailJson(Exception):
    """Raised by the patched ``fail_json`` in place of a failing module exit.

    The single positional argument is the failure dict the module tried to
    report, so a test can inspect it via ``exc.args[0]``.
    """


def exit_json(*args, **kwargs):
    """Stand-in for ``AnsibleModule.exit_json`` used by the tests.

    Rather than terminating, bundle the keyword arguments into a result dict
    and raise it inside ``AnsibleExitJson``. ``changed`` is filled in as
    ``False`` when the caller did not supply it. Positional arguments (e.g.
    the module instance when installed as a method) are ignored.
    """
    kwargs.setdefault('changed', False)
    raise AnsibleExitJson(kwargs)


def fail_json(*args, **kwargs):
    """Stand-in for ``AnsibleModule.fail_json`` used by the tests.

    Rather than terminating, bundle the keyword arguments into a result dict
    with ``failed`` forced to ``True`` and raise it inside ``AnsibleFailJson``.
    Positional arguments (e.g. the module instance when installed as a
    method) are ignored.
    """
    payload = kwargs
    payload.update(failed=True)
    raise AnsibleFailJson(payload)


class TestModuleBase(unittest.TestCase):
    """Base class for Ansible module unit tests.

    ``setUp`` patches ``AnsibleModule.exit_json`` / ``fail_json`` so that they
    raise ``AnsibleExitJson`` / ``AnsibleFailJson`` instead of exiting the
    process. The helper methods inject module arguments and manage temporary
    files and directories. Every change they make is undone through
    ``addCleanup`` when the test finishes.
    """

    def setUp(self):
        super().setUp()
        # Save and restore the working directory, since tests may chdir.
        original_dir = os.getcwd()
        self.addCleanup(os.chdir, original_dir)

        self.mock_module_helper = patch.multiple(basic.AnsibleModule,  # @UndefinedVariable
                                                 exit_json=exit_json,
                                                 fail_json=fail_json)
        self.mock_module_helper.start()
        self.addCleanup(self.mock_module_helper.stop)

    def set_module_args(self, args):
        """Prepare arguments so that they are picked up during module creation.

        The arguments are JSON-encoded under ``ANSIBLE_MODULE_ARGS`` and stored
        in ``basic._ANSIBLE_ARGS``. When the test finishes, the previous value
        is restored. Previously the value was reset to ``{}``, which is not
        bytes and differs from the prior value. Calling this repeatedly within
        one test unwinds correctly because cleanups run in reverse order.

        Args:
            args: A JSON-serializable mapping of module parameters.
        """
        previous = getattr(basic, '_ANSIBLE_ARGS', None)

        def restore_args():
            basic._ANSIBLE_ARGS = previous

        json_args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
        basic._ANSIBLE_ARGS = to_bytes(json_args)
        self.addCleanup(restore_args)

    def setup_tmp_dir(self):
        """Create a temporary directory that is removed when the test ends.

        Returns:
            The path of the new directory as a string.
        """
        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        return tmp.name

    def write_file(self, filepath: str, content: str) -> None:
        """Write ``content`` to ``filepath`` as UTF-8 text, replacing any existing file.

        The encoding is explicit so that results do not depend on the
        platform's default locale encoding.
        """
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)

    @contextlib.contextmanager
    def in_dir(self, directory):
        """Temporarily change the working directory to ``directory``.

        The previous working directory is restored on exit, even if the body
        raises.
        """
        old_dir = os.getcwd()
        os.chdir(directory)
        try:
            yield
        finally:
            os.chdir(old_dir)
