#!/usr/bin/env python
""" helper for encrypting/decrypting Cfg and Properties files """

import os
import sys
import copy
import select
import logging
import lxml.etree
import Bcfg2.Logger
import Bcfg2.Options
from Bcfg2.Server import XMLParser
from Bcfg2.Compat import input  # pylint: disable=W0622
try:
    import Bcfg2.Encryption
except ImportError:
    print("Could not import %s. Is M2Crypto installed?" % sys.exc_info()[1])
    raise SystemExit(1)


class EncryptionChunkingError(Exception):
    """ raised when an Encryptor is unable to split a file into the
    chunks that get encrypted, or is unable to put those chunks back
    together again """


class Encryptor(object):
    """ Generic encryptor for all files """

    def __init__(self, setup):
        self.setup = setup
        self.passphrase = None
        self.pname = None
        self.logger = logging.getLogger(self.__class__.__name__)

    def get_encrypted_filename(self, plaintext_filename):
        """ get the name of the file encrypted data should be written to """
        return plaintext_filename

    def get_plaintext_filename(self, encrypted_filename):
        """ get the name of the file decrypted data should be written to """
        return encrypted_filename

    def chunk(self, data):
        """ generator to break the file up into smaller chunks that
        will each be individually encrypted or decrypted """
        yield data

    def unchunk(self, data, original):  # pylint: disable=W0613
        """ given chunks of a file, reassemble then into the whole file """
        try:
            return data[0]
        except IndexError:
            raise EncryptionChunkingError("No data to unchunk")

    def set_passphrase(self):
        """ set the passphrase for the current file """
        if (not self.setup.cfp.has_section(Bcfg2.Encryption.CFG_SECTION) or
            len(Bcfg2.Encryption.get_passphrases(self.setup)) == 0):
            self.logger.error("No passphrases available in %s" %
                              self.setup['configfile'])
            return False

        if self.passphrase:
            self.logger.debug("Using previously determined passphrase %s" %
                              self.pname)
            return True

        if self.setup['passphrase']:
            self.pname = self.setup['passphrase']

        if self.pname:
            if self.setup.cfp.has_option(Bcfg2.Encryption.CFG_SECTION,
                                         self.pname):
                self.passphrase = \
                    self.setup.cfp.get(Bcfg2.Encryption.CFG_SECTION,
                                       self.pname)
                self.logger.debug("Using passphrase %s specified on command "
                                  "line" % self.pname)
                return True
            else:
                self.logger.error("Could not find passphrase %s in %s" %
                                  (self.pname, self.setup['configfile']))
                return False
        else:
            pnames = Bcfg2.Encryption.get_passphrases(self.setup)
            if len(pnames) == 1:
                self.pname = pnames.keys()[0]
                self.passphrase = pnames[self.pname]
                self.logger.info("Using passphrase %s" % self.pname)
                return True
            elif len(pnames) > 1:
                self.logger.warning("Multiple passphrases found in %s, "
                                    "specify one on the command line with -p" %
                                    self.setup['configfile'])
        self.logger.info("No passphrase could be determined")
        return False

    def encrypt(self, fname):
        """ encrypt the given file, returning the encrypted data """
        try:
            plaintext = open(fname).read()
        except IOError:
            err = sys.exc_info()[1]
            self.logger.error("Error reading %s, skipping: %s" % (fname, err))
            return False

        if not self.set_passphrase():
            return False

        crypted = []
        try:
            for chunk in self.chunk(plaintext):
                try:
                    passphrase, pname = self.get_passphrase(chunk)
                except TypeError:
                    return False

                crypted.append(self._encrypt(chunk, passphrase, name=pname))
        except EncryptionChunkingError:
            err = sys.exc_info()[1]
            self.logger.error("Error getting data to encrypt from %s: %s" %
                              (fname, err))
            return False
        return self.unchunk(crypted, plaintext)

    #  pylint: disable=W0613
    def _encrypt(self, plaintext, passphrase, name=None):
        """ encrypt a single chunk of a file """
        return Bcfg2.Encryption.ssl_encrypt(
            plaintext, passphrase,
            Bcfg2.Encryption.get_algorithm(self.setup))
    #  pylint: enable=W0613

    def decrypt(self, fname):
        """ decrypt the given file, returning the plaintext data """
        try:
            crypted = open(fname).read()
        except IOError:
            err = sys.exc_info()[1]
            self.logger.error("Error reading %s, skipping: %s" % (fname, err))
            return False

        self.set_passphrase()

        plaintext = []
        try:
            for chunk in self.chunk(crypted):
                try:
                    passphrase, pname = self.get_passphrase(chunk)
                    try:
                        plaintext.append(self._decrypt(chunk, passphrase))
                    except Bcfg2.Encryption.EVPError:
                        self.logger.info("Could not decrypt %s with the "
                                         "specified passphrase" % fname)
                        continue
                    except:
                        err = sys.exc_info()[1]
                        self.logger.error("Error decrypting %s: %s" %
                                          (fname, err))
                        continue
                except TypeError:
                    pchunk = None
                    passphrases = Bcfg2.Encryption.get_passphrases(self.setup)
                    for pname, passphrase in passphrases.items():
                        self.logger.debug("Trying passphrase %s" % pname)
                        try:
                            pchunk = self._decrypt(chunk, passphrase)
                            break
                        except Bcfg2.Encryption.EVPError:
                            pass
                        except:
                            err = sys.exc_info()[1]
                            self.logger.error("Error decrypting %s: %s" %
                                              (fname, err))
                    if pchunk is not None:
                        plaintext.append(pchunk)
                    else:
                        self.logger.error("Could not decrypt %s with any "
                                          "passphrase in %s" %
                                          (fname, self.setup['configfile']))
                        continue
        except EncryptionChunkingError:
            err = sys.exc_info()[1]
            self.logger.error("Error getting encrypted data from %s: %s" %
                              (fname, err))
            return False

        try:
            return self.unchunk(plaintext, crypted)
        except EncryptionChunkingError:
            err = sys.exc_info()[1]
            self.logger.error("Error assembling plaintext data from %s: %s" %
                              (fname, err))
            return False

    def _decrypt(self, crypted, passphrase):
        """ decrypt a single chunk """
        return Bcfg2.Encryption.ssl_decrypt(
            crypted, passphrase,
            Bcfg2.Encryption.get_algorithm(self.setup))

    def write_encrypted(self, fname, data=None):
        """ write encrypted data to disk """
        if data is None:
            data = self.decrypt(fname)
        new_fname = self.get_encrypted_filename(fname)
        try:
            open(new_fname, "wb").write(data)
            self.logger.info("Wrote encrypted data to %s" % new_fname)
            return True
        except IOError:
            err = sys.exc_info()[1]
            self.logger.error("Error writing encrypted data from %s to %s: %s"
                              % (fname, new_fname, err))
            return False
        except EncryptionChunkingError:
            err = sys.exc_info()[1]
            self.logger.error("Error assembling encrypted data from %s: %s" %
                              (fname, err))
            return False

    def write_decrypted(self, fname, data=None):
        """ write decrypted data to disk """
        if data is None:
            data = self.decrypt(fname)
        new_fname = self.get_plaintext_filename(fname)
        try:
            open(new_fname, "wb").write(data)
            self.logger.info("Wrote decrypted data to %s" % new_fname)
            return True
        except IOError:
            err = sys.exc_info()[1]
            self.logger.error("Error writing encrypted data from %s to %s: %s"
                              % (fname, new_fname, err))
            return False

    def get_passphrase(self, chunk):
        """ get the passphrase for a chunk of a file """
        pname = self._get_passphrase(chunk)
        if not self.pname:
            if not pname:
                self.logger.info("No passphrase given on command line or "
                                 "found in file")
                return False
            elif self.setup.cfp.has_option(Bcfg2.Encryption.CFG_SECTION,
                                           pname):
                passphrase = self.setup.cfp.get(Bcfg2.Encryption.CFG_SECTION,
                                                pname)
            else:
                self.logger.error("Could not find passphrase %s in %s" %
                                  (pname, self.setup['configfile']))
                return False
        else:
            pname = self.pname
            passphrase = self.passphrase
            if self.pname != pname:
                self.logger.warning("Passphrase given on command line (%s) "
                                    "differs from passphrase embedded in "
                                    "file (%s), using command-line option" %
                                    (self.pname, pname))
        return (passphrase, pname)

    def _get_passphrase(self, chunk):  # pylint: disable=W0613
        """ get the passphrase for a chunk of a file """
        return None


class CfgEncryptor(Encryptor):
    """ Encryptor for Cfg files: encrypted data lives next to the
    plaintext file, in a file with ".crypt" appended to its name """

    def get_encrypted_filename(self, plaintext_filename):
        """ the encrypted file is the plaintext name plus ".crypt" """
        return "".join([plaintext_filename, ".crypt"])

    def get_plaintext_filename(self, encrypted_filename):
        """ strip a trailing ".crypt"; any other name is handed to the
        generic Encryptor implementation """
        suffix = ".crypt"
        if not encrypted_filename.endswith(suffix):
            return Encryptor.get_plaintext_filename(self, encrypted_filename)
        return encrypted_filename[:-len(suffix)]


class PropertiesEncryptor(Encryptor):
    """ encryptor class for Properties files.

    Each chunk is an XML element whose text is encrypted or decrypted
    in place; the ``encrypted`` attribute of an element records the
    name of the passphrase used (or "true" if none was named). """

    def _encrypt(self, plaintext, passphrase, name=None):
        """ encrypt the text of a single element in place and mark the
        element with the ``encrypted`` attribute.  Elements with no
        text (or only whitespace) are returned untouched. """
        # plaintext is an lxml.etree._Element
        if name is None:
            name = "true"
        if plaintext.text and plaintext.text.strip():
            plaintext.text = Bcfg2.Encryption.ssl_encrypt(
                plaintext.text,
                passphrase,
                Bcfg2.Encryption.get_algorithm(self.setup)).strip()
            plaintext.set("encrypted", name)
        return plaintext

    def chunk(self, data):
        """ generator yielding the elements to encrypt or decrypt.

        Elements are selected by the ``xpath`` option if given;
        otherwise all elements with an ``encrypted`` attribute are
        used, or, failing that, every element in the document.
        Elements without text are dropped, and in interactive mode
        the user confirms each remaining element.

        Raises EncryptionChunkingError if the ``xpath`` option matches
        nothing. """
        xdata = lxml.etree.XML(data, parser=XMLParser)
        xpath = self.setup['xpath']
        if xpath:
            elements = xdata.xpath(xpath)
            if not elements:
                raise EncryptionChunkingError("XPath expression %s matched no "
                                              "elements" % xpath)
        else:
            elements = xdata.xpath('//*[@encrypted]')
            if not elements:
                elements = list(xdata.getiterator(tag=lxml.etree.Element))

        # filter out elements without text data
        elements = [el for el in elements if el.text]

        if self.setup['interactive']:
            selected = []
            for element in elements:
                if len(element):
                    # show only the element itself, not its children;
                    # work on a copy so the real document is untouched
                    elt = copy.copy(element)
                    # iterate over a snapshot, since children are
                    # removed as we go
                    for child in list(elt):
                        elt.remove(child)
                else:
                    elt = element
                print(lxml.etree.tostring(
                    elt,
                    xml_declaration=False).decode("UTF-8").strip())
                # flush input buffer
                while len(select.select([sys.stdin.fileno()], [], [],
                                        0.0)[0]) > 0:
                    os.read(sys.stdin.fileno(), 4096)
                ans = input("Encrypt this element? [y/N] ")
                if ans.lower().startswith("y"):
                    selected.append(element)
            elements = selected

        # this is not a good use of a generator, but we need to
        # generate the full list of elements in order to ensure that
        # some exist before we know what to return
        for elt in elements:
            yield elt

    def unchunk(self, data, original):
        """ serialize the document that the processed elements belong
        to.

        Raises EncryptionChunkingError if no elements were processed
        (e.g., no element had text, or every element was declined in
        interactive mode). """
        # Properties elements are modified in-place, so we don't
        # actually need to unchunk anything
        try:
            xdata = data[0]
        except IndexError:
            raise EncryptionChunkingError("No data to unchunk")
        # find root element
        while xdata.getparent() is not None:
            xdata = xdata.getparent()
        return lxml.etree.tostring(xdata,
                                   xml_declaration=False,
                                   pretty_print=True).decode('UTF-8')

    def _get_passphrase(self, chunk):
        """ get the passphrase name recorded in the ``encrypted``
        attribute of an element, or None if the attribute is missing
        or is just "true" """
        pname = chunk.get("encrypted")
        if pname and pname.lower() != "true":
            return pname
        return None

    def _decrypt(self, crypted, passphrase):
        """ decrypt the text of a single element in place.  Elements
        with no text (or only whitespace) are returned untouched. """
        # crypted is in lxml.etree._Element
        if not crypted.text or not crypted.text.strip():
            self.logger.warning("Skipping empty element %s" % crypted.tag)
            return crypted
        decrypted = Bcfg2.Encryption.ssl_decrypt(
            crypted.text,
            passphrase,
            Bcfg2.Encryption.get_algorithm(self.setup)).strip()
        try:
            crypted.text = decrypted.encode('ascii', 'xmlcharrefreplace')
        except UnicodeDecodeError:
            # we managed to decrypt the value, but it contains content
            # that can't even be encoded into xml entities.  what
            # probably happened here is that we coincidentally could
            # decrypt a value encrypted with a different key, and
            # wound up with gibberish.
            self.logger.warning("Decrypted %s to gibberish, skipping" %
                                crypted.tag)
        return crypted


def main():  # pylint: disable=R0912,R0915
    """ command-line entry point: parse the options, then encrypt or
    decrypt each file named on the command line, either writing the
    result to disk or printing it to stdout.

    Exits with status 1 on usage errors; per-file failures are logged
    and the file is skipped. """
    optinfo = dict(interactive=Bcfg2.Options.INTERACTIVE)
    optinfo.update(Bcfg2.Options.CRYPT_OPTIONS)
    optinfo.update(Bcfg2.Options.CLI_COMMON_OPTIONS)
    setup = Bcfg2.Options.OptionParser(optinfo)
    setup.hm = "     bcfg2-crypt [options] <filename>\nOptions:\n%s" % \
        setup.buildHelpMessage()
    setup.parse(sys.argv[1:])

    if not setup['args']:
        print(setup.hm)
        raise SystemExit(1)

    log_args = dict(to_syslog=setup['syslog'], to_console=logging.WARNING)
    if setup['verbose']:
        log_args['to_console'] = logging.DEBUG
    Bcfg2.Logger.setup_logging('bcfg2-crypt', **log_args)
    logger = logging.getLogger('bcfg2-crypt')

    if setup['decrypt']:
        if setup['encrypt']:
            logger.error("You cannot specify both --encrypt and --decrypt")
            raise SystemExit(1)
        # --remove and -I are checked independently so that both are
        # reset when both were given alongside --decrypt
        if setup['remove']:
            logger.error("--remove cannot be used with --decrypt, ignoring")
            setup['remove'] = Bcfg2.Options.CRYPT_REMOVE.default
        if setup['interactive']:
            logger.error("Cannot decrypt interactively")
            setup['interactive'] = False

    if setup['cfg']:
        if setup['properties']:
            logger.error("You cannot specify both --cfg and --properties")
            raise SystemExit(1)
        if setup['xpath']:
            logger.error("Specifying --xpath with --cfg is nonsensical, "
                         "ignoring --xpath")
            setup['xpath'] = Bcfg2.Options.CRYPT_XPATH.default
        if setup['interactive']:
            logger.error("You cannot use interactive mode with --cfg, "
                         "ignoring -I")
            setup['interactive'] = False
    elif setup['properties']:
        if setup['remove']:
            logger.error("--remove cannot be used with --properties, ignoring")
            setup['remove'] = Bcfg2.Options.CRYPT_REMOVE.default

    props_crypt = PropertiesEncryptor(setup)
    cfg_crypt = CfgEncryptor(setup)

    for fname in setup['args']:
        if not os.path.exists(fname):
            logger.error("%s does not exist, skipping" % fname)
            continue

        # figure out if we need to encrypt this as a Properties file
        # or as a Cfg file
        if setup['properties']:
            props = True
        elif setup['cfg'] or not fname.endswith(".xml"):
            props = False
        else:
            try:
                xroot = lxml.etree.parse(fname).getroot()
                props = xroot.tag == "Properties"
            except IOError:
                err = sys.exc_info()[1]
                logger.error("Error reading %s, skipping: %s" % (fname, err))
                continue
            except lxml.etree.XMLSyntaxError:
                # not parseable as XML, so treat it as a Cfg file
                props = False

        if props:
            encryptor = props_crypt
            if setup['remove']:
                logger.info("Cannot use --remove with Properties file %s, "
                            "ignoring for this file" % fname)
        else:
            if setup['xpath']:
                logger.info("Cannot use xpath with Cfg file %s, ignoring "
                            "xpath for this file" % fname)
            if setup['interactive']:
                logger.info("Cannot use interactive mode with Cfg file %s, "
                            "ignoring -I for this file" % fname)
            encryptor = cfg_crypt

        data = None
        if setup['encrypt']:
            xform = encryptor.encrypt
            write = encryptor.write_encrypted
        elif setup['decrypt']:
            xform = encryptor.decrypt
            write = encryptor.write_decrypted
        else:
            # no mode given: if the file can be decrypted it is taken
            # to be encrypted, otherwise it is encrypted
            logger.info("Neither --encrypt nor --decrypt specified, "
                        "determining mode")
            data = encryptor.decrypt(fname)
            if data:
                write = encryptor.write_decrypted
            else:
                logger.info("Failed to decrypt %s, trying encryption" % fname)
                data = None
                xform = encryptor.encrypt
                write = encryptor.write_encrypted

        if data is None:
            data = xform(fname)
        if not data:
            logger.error("Failed to %s %s, skipping" % (xform.__name__, fname))
            continue
        if setup['crypt_stdout']:
            if len(setup['args']) > 1:
                print("----- %s -----" % fname)
            print(data)
            if len(setup['args']) > 1:
                print("")
        elif not write(fname, data=data):
            # the failure was logged by the encryptor.  the source
            # file must never be removed if its replacement could not
            # be written
            continue

        if (setup['remove'] and
            encryptor.get_encrypted_filename(fname) != fname):
            try:
                os.unlink(fname)
            except (IOError, OSError):
                # os.unlink reports failures as OSError, which is not
                # an IOError on Python 2
                err = sys.exc_info()[1]
                logger.error("Error removing %s: %s" % (fname, err))

if __name__ == '__main__':
    # run the command-line tool only when executed as a script, and
    # hand whatever main() returns to the interpreter as exit status
    raise SystemExit(main())
