#!/usr/bin/env python3

import locale
import argparse
import os
import io
import re
try:
    import readline
except ImportError:
    pass
from stat import S_ISDIR, S_ISREG
import sys
import tempfile

from chardet.universaldetector import UniversalDetector


VERSION = "1.6.1"


# Adapted from: https://stackoverflow.com/questions/24528278/stream-multiple-files-into-a-readable-object-in-python
# Note: the original code is licensed under CC-BY-SA 3.0, which is
# upwards-compatible with 4.0, and hence compatible with GPLv3.
class ChainStream(io.RawIOBase):
    """
    Present an iterable of binary streams as one continuous raw stream.

    The streams are consumed one after another; each one is closed as soon
    as it reports end-of-file, and the next one is pulled from the iterable
    lazily.

    Usage:
        def generate_open_file_streams():
            for file in filenames:
                yield open(file, 'rb')
        f = io.BufferedReader(ChainStream(generate_open_file_streams()))
        f.read()
    """
    def __init__(self, streams):
        # Bytes already read from a stream but not yet handed to the caller.
        self.leftover = b''
        self.stream_iter = iter(streams)
        # None means "no current stream" (empty iterable or fully consumed).
        self.stream = next(self.stream_iter, None)

    def readable(self):
        """Report that this stream supports reading."""
        return True

    def _read_next_chunk(self, max_length):
        """Return pending leftover bytes if any, else read from the current stream.

        Returns b'' when there is no current stream or it is exhausted.
        """
        if self.leftover:
            return self.leftover
        if self.stream is None:
            return b''
        return self.stream.read(max_length)

    def readinto(self, b):
        """Fill *b* with up to len(b) bytes and return the number written.

        Returns 0 once every chained stream has been exhausted.
        """
        capacity = len(b)
        data = self._read_next_chunk(capacity)
        while len(data) == 0:
            # The current stream is drained: close it and advance.
            if self.stream is not None:
                self.stream.close()
            try:
                self.stream = next(self.stream_iter)
            except StopIteration:
                # Nothing left to chain together.
                self.stream = None
                return 0  # indicate EOF
            else:
                data = self._read_next_chunk(capacity)
        # Hand over what fits; keep the remainder for the next call.
        self.leftover = data[capacity:]
        head = data[:capacity]
        b[:len(head)] = head
        return len(head)


def warn(message):
    """Write *message* to standard error on a fresh line, prefixed with ``rpl: ``."""
    text = "\nrpl: " + format(message)
    print(text, file=sys.stderr)

def get_files(filenames, verbose, hidden_files):
    """Yield ``(filename, stat_result)`` pairs for the regular files in *filenames*.

    Each name is examined with ``os.lstat``. Names that cannot be stat'ed and
    names that are neither regular files nor directories produce a warning and
    are skipped; directories are skipped too, with a warning only when
    *verbose* is true.

    *hidden_files* is accepted for interface compatibility but is not used by
    this function.
    """
    for name in filenames:
        try:
            info = os.lstat(name)
        except OSError as err:
            warn("SKIPPING {}: unable to read permissions; error: {}".format(name, err))
            continue

        mode = info.st_mode
        if S_ISREG(mode):
            yield name, info
        elif S_ISDIR(mode):
            if verbose:
                warn("SKIPPING directory {}".format(name))
        else:
            warn("SKIPPING: {} (not a regular file)".format(name))


def unescape(s):
    """Expand backslash escapes in *s* and return the resulting string.

    Recognised escapes are one to three octal digits (``\\101``), ``\\x``
    followed by exactly two hex digits (``\\x41``), and the single-character
    escapes ``\\n \\r \\t \\v \\a \\f \\b \\\\``. Any other backslash sequence
    is left untouched.

    The escapes are decoded explicitly rather than by ``eval``-ing a string
    literal, so no user-supplied text is ever executed and large octal values
    (above ``\\377``) do not go through the string-literal parser.
    """
    simple = {
        'n': '\n', 'r': '\r', 't': '\t', 'v': '\v',
        'a': '\a', 'f': '\f', 'b': '\b', '\\': '\\',
    }

    def expand(match):
        body = match.group(1)
        if body in simple:
            return simple[body]
        if body[0] == 'x':
            return chr(int(body[1:], 16))
        # Only the octal alternative is left.
        return chr(int(body, 8))

    return re.sub(r'\\([0-7]{1,3}|x[0-9a-fA-F]{2}|[nrtvafb\\])', expand, s)

def casetype(string):
    """Classify the capitalisation of *string*.

    Returns:
        0 if the string is empty or its first character is not upper case,
        2 if the first character is upper case and every following character
          is upper case too (this includes one-character strings),
        1 otherwise (capitalised only).
    """
    if not string or not string[0].isupper():
        return 0
    # Each remaining character is tested on its own, so a digit or
    # punctuation mark after the first letter prevents the "all upper" result.
    if all(ch.isupper() for ch in string[1:]):
        return 2
    return 1

def caselike(model, string):
    """Return *string* re-cased to follow the capitalisation of *model*.

    The capitalisation class of *model* is taken from ``casetype``: class 2
    upper-cases the whole string, class 1 upper-cases only its first
    character, and class 0 leaves it unchanged. An empty *string* is returned
    as is.
    """
    if len(string) == 0:
        return string
    kind = casetype(model)
    if kind == 2:
        return string.upper()
    if kind == 1:
        return string[0].upper() + string[1:]
    return string

def replace(instream, outstream, regex, before, after, encoding):
    """Copy *instream* to *outstream*, replacing every match of *regex*.

    Parameters:
        instream:  binary stream to read from, in io.DEFAULT_BUFFER_SIZE blocks.
        outstream: binary stream that receives the rewritten data.
        regex:     compiled pattern to search for. If it was compiled with
                   re.I, each replacement is re-cased via ``caselike`` to
                   follow the text it replaces.
        before:    the search text; only its length is used, as the size of
                   the tail that must be held back between blocks. Matches
                   are assumed to be exactly this long.
        after:     the replacement text.
        encoding:  name of the codec used to decode input and encode output.

    Returns:
        The number of matches found, or 0 if the codec is unknown or the
        input cannot be decoded (a warning is printed in that case, and
        *outstream* may already contain partial output).

    The data is decoded and encoded incrementally, so a multi-byte character
    split across two reads is handled correctly and codecs that write a
    byte-order mark write it only once. A match that touches the end of the
    buffered text is not decided until more input (or end of file) is seen,
    and one already-written character is kept as left context, so ``\\b``
    anchors are evaluated against the real neighbouring characters rather
    than against an artificial block edge.
    """
    # Imported here so the function does not depend on the module's import block.
    import codecs

    try:
        decoder = codecs.getincrementaldecoder(encoding)()
        encoder = codecs.getincrementalencoder(encoding)()
    except (LookupError, TypeError):
        # Unknown codec name, or no codec name at all (None).
        warn("error decoding file; aborting!")
        return 0

    patlen = len(before)
    ignore_case = regex.flags & re.I != 0
    total = 0

    # buf[:start] is context that has already been written out;
    # buf[start:] is text whose fate is still undecided.
    buf = ''
    start = 0

    while True:
        block = instream.read(io.DEFAULT_BUFFER_SIZE)
        final = len(block) == 0

        try:
            buf += decoder.decode(block, final)
        except UnicodeError:
            warn("error decoding file; aborting!")
            return 0

        pieces = []
        pos = start
        for match in regex.finditer(buf, start):
            if not final and match.end() == len(buf):
                # The character after the match is not known yet; decide
                # once the next block has arrived.
                break
            pieces.append(buf[pos:match.start()])
            found = match.group()
            if found != '':
                pieces.append(caselike(found, after) if ignore_case else after)
            else:
                pieces.append(found)
            total += 1
            pos = match.end()

        if final:
            pieces.append(buf[pos:])
            outstream.write(encoder.encode(''.join(pieces), True))
            break

        # Hold back the last patlen characters: they may be the beginning of
        # a match that continues in the next block.
        hold = max(pos, len(buf) - patlen)
        pieces.append(buf[pos:hold])
        outstream.write(encoder.encode(''.join(pieces)))

        # Keep one written character in front of the held-back tail as
        # left context for the next search.
        context = max(hold - 1, 0)
        buf = buf[context:]
        start = hold - context

    return total


# Build the command-line interface: --version and --encoding first, then the
# boolean switches in the order they are listed in --help, then positionals.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description="Search and replace text in files.",
)

parser.add_argument(
    '--version',
    action='version',
    version="%(prog)s " + VERSION + '''
Copyright (C) 2004-2005 Göran Weinholt <weinholt@debian.org>
Copyright (C) 2004 Christian Häggström <chm@c00.info>
Copyright (C) 2016 Kevin Coyner <kcoyner@debian.org>
Copyright (C) 2017 Jochen Kupperschmidt <homework@nwsnet.de>
Copyright (C) 2018-2019 Reuben Thomas <rrt@sc3d.org>

rpl comes with ABSOLUTELY NO WARRANTY.
You may redistribute copies of rpl under the terms of the
GNU General Public License.
For more information about these matters, see the file named COPYING.''',
)

parser.add_argument(
    "--encoding",
    metavar="ENCODING",
    help="specify character set encoding",
)


def _add_switch(short_name, long_name, description):
    """Register an on/off option on ``parser`` that stores True when given."""
    parser.add_argument(short_name, long_name, action="store_true", help=description)


_add_switch("-i", "--ignore-case", "match case-insensitively")
_add_switch("-w", "--whole-words", "whole words (old_string matches on word boundaries only)")
_add_switch("-b", "--backup", "rename original file to file~ before replacing")
_add_switch("-q", "--quiet", "quiet mode")
_add_switch("-v", "--verbose", "verbose mode")
_add_switch("-s", "--dry-run", "simulation mode")
_add_switch("-e", "--escape", "expand escapes in old_string and new_string")
_add_switch("-p", "--prompt", "prompt before modifying each file")
_add_switch("-f", "--force", "ignore errors when trying to preserve permissions")
_add_switch("-d", "--keep-times", "keep the modification times on modified files")
_add_switch("-a", "--hidden-files", "process hidden files and directories (starting with a period)")

# Positional arguments: search text, replacement text, one or more files.
parser.add_argument('old_str', metavar='OLD-TEXT')
parser.add_argument('new_str', metavar='NEW-TEXT')
parser.add_argument('file', metavar='FILE', nargs='+')

args = parser.parse_args()

# See if all the files actually exist
files = args.file
for file in files:
    if not os.path.exists(file):
        warn("File \"{}\" not found".format(file))
        sys.exit(os.EX_DATAERR)

old_str = args.old_str
new_str = args.new_str
if new_str == "" and not args.quiet:
    print("Really DELETE all occurrences of {} ({})? (Y/[N]) ".format(
        old_str,
        "ignoring case" if args.ignore_case else "case sensitive"
    ), file=sys.stderr, end='')

    # The prompt advertises N as the default, so only an explicit "y"/"Y"
    # answer goes ahead; an empty line or anything else cancels.
    line = input()
    if line == "" or line[0] not in "yY":
        warn("User cancelled operation.")
        sys.exit(os.EX_TEMPFAIL)

# Tell the user what is going to happen
warn("{} \"{}\" with \"{}\" ({}; {})".format(
    "Simulating replacement of" if args.dry_run else "Replacing",
    old_str,
    new_str,
    "ignoring case" if args.ignore_case else "case sensitive",
    "whole words only" if args.whole_words else "partial words matched",
))

if args.dry_run and not args.quiet:
    warn("The files listed below would be modified in a replace operation")

# Encoding given on the command line, or None to guess one per file.
encoding = None
if args.encoding:
    encoding = args.encoding

if args.escape:
    old_str = unescape(old_str)
    new_str = unescape(new_str)

# The search text is matched literally; the single capture group wraps the
# whole pattern.
regex_str = re.escape(old_str)
if args.whole_words:
    regex_str = r"\b" + regex_str + r"\b"
regex = re.compile(r"(" + regex_str + r")", re.I if args.ignore_case else 0)

total_files = 0
total_matches = 0
files = get_files(files, args.verbose, args.hidden_files)
for filename, perms in files:
    total_files += 1

    # Open the input file
    try:
        f = open(filename, "rb")
    except IOError as e:
        warn("SKIPPING {}: cannot open for reading; error: {}".format(filename, e))
        continue

    # Create the output file in the same directory as the target, so that
    # the final os.rename() does not have to cross a filesystem boundary.
    try:
        o, tmp_path = tempfile.mkstemp("", ".tmp.", os.path.dirname(os.path.abspath(filename)))
        o = os.fdopen(o, "wb")
    except OSError as e:
        warn("SKIPPING {}: cannot create temp file; error: {}".format(filename, e))
        f.close()
        continue

    # Set permissions and owner
    try:
        os.chown(tmp_path, perms.st_uid, perms.st_gid)
        os.chmod(tmp_path, perms.st_mode)
    except OSError as e:
        warn("Unable to set owner/group/perms of {}; error: {}".format(filename, e))
        if args.force:
            warn("WARNING: New owner/group/perms may not match!\n")
        else:
            warn("SKIPPING {}!\n".format(filename))
            f.close()
            o.close()
            os.unlink(tmp_path)
            continue

    if args.verbose and not args.dry_run:
        warn("Processing: {}".format(filename))
    elif not args.quiet and not args.dry_run:
        print(".", file=sys.stderr, flush=True, end='')

    # If we don't have an explicit encoding, guess one for this file only;
    # the guess must not leak into the files that follow.
    file_encoding = encoding
    reader = f
    if file_encoding is None:
        detector = UniversalDetector()
        consumed = []
        while True:
            next_block = f.read(io.DEFAULT_BUFFER_SIZE)
            if len(next_block) == 0:
                break
            consumed.append(next_block)
            detector.feed(next_block)
            if detector.done:
                break
        # Replay the bytes already consumed by the detector, then continue
        # with the rest of the file.
        reader = io.BufferedReader(ChainStream([io.BytesIO(b''.join(consumed)), f]))

        detector.close()
        # NOTE(review): the reported encoding is treated as possibly
        # None/empty when the detector has no answer — confirm against the
        # chardet documentation.
        guessed = detector.result['encoding'] if detector.done else None
        if guessed:
            file_encoding = guessed
            if args.verbose:
                warn("guessed encoding '{}'".format(file_encoding))
        else:
            file_encoding = locale.getpreferredencoding(False)
            warn("could not guess encoding; using locale default '{}'".format(file_encoding))

    # Do the actual work now
    matches = replace(reader, o, regex, old_str, new_str, file_encoding)

    # Closing the wrapper does not necessarily close the underlying file
    # (e.g. when replace() stopped early), so close both.
    reader.close()
    f.close()
    o.close()

    if matches == 0:
        os.unlink(tmp_path)
        continue

    if args.dry_run:
        try:
            fn = os.path.realpath(filename)
        except OSError:
            fn = filename

        if not args.quiet:
            print("  {}".format(fn), file=sys.stderr)

        os.unlink(tmp_path)
        total_matches += matches
        continue

    if args.prompt:
        print("\nSave \"{}\"? ([Y]/N) ".format(filename), file=sys.stderr, end='')

        # input() strips the trailing newline, so an empty answer means the
        # user accepted the advertised default (Y). Anything that does not
        # start with Y or N is asked again.
        line = input()
        while line != "" and line[0] not in "YyNn":
            line = input()

        if line != "" and line[0] in "nN":
            print("Not saved", file=sys.stderr)
            os.unlink(tmp_path)
            continue

        print("Saved", file=sys.stderr)

    if args.backup:
        try:
            os.rename(filename, filename + "~")
        except OSError as e:
            warn("Error renaming {} to {}: {}".format(filename, filename + "~", e))
            os.unlink(tmp_path)
            continue

    # Rename the file
    try:
        os.rename(tmp_path, filename)
    except OSError as e:
        warn("Could not replace {} with {}; error: {}".format(tmp_path, filename, e))
        os.unlink(tmp_path)
        continue

    # Restore the times
    if args.keep_times:
        try:
            os.utime(filename, (perms.st_atime, perms.st_mtime))
        except OSError as e:
            warn("Error setting timestamps of {}: {}".format(filename, e))

    total_matches += matches

# We're about to exit, give a summary
if not args.quiet:
    warn("A total of {} matches {} in {} file{} searched".format(
        total_matches,
        "found" if args.dry_run else "replaced",
        total_files,
        "s" if total_files != 1 else "",
    ))
    if args.dry_run:
        warn("None replaced (simulation mode)")
