#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2012 Red Hat, Inc.
#
# Authors:
# Thomas Woerner <twoerner@redhat.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import os, os.path
from copy import copy

from firewall.config import *
from firewall.core.io.firewalld_conf import firewalld_conf
from firewall.core.io.zone import Zone, zone_reader, zone_writer
from optparse import Option, OptionError, OptionParser, Values, \
    SUPPRESS_HELP, BadOptionError, OptionGroup
from firewall.functions import getPortID, getPortRange, getServiceName, \
    checkIP, checkInterface

# Only root (uid 0) may run this tool; everybody else gets a message and a
# non-zero exit status.
if os.getuid():
    print(_("You need to be root to run %s.") % sys.argv[0])
    sys.exit(-1)

def usage():
    """Print a one-line usage hint that names the running program."""
    program = sys.argv[0]
    print("Usage: %s -h | --help" % program)

def __fail(msg=None):
    """Terminate the program with exit status 2.

    A truthy *msg* is printed before exiting; a missing or empty one is
    skipped silently.
    """
    if not msg:
        sys.exit(2)
    print(msg)
    sys.exit(2)

# system-config-firewall: fw_parser

def _check_port(option, opt, value):
    """optparse type checker for "<port>[-<port>]:<protocol>" values.

    @param option the Option being processed (unused)
    @param opt the option string as given on the command line
    @param value the raw option argument
    @return tuple (ports, protocol), both stripped of surrounding whitespace
    @raise OptionError if the value is not exactly "<ports>:<protocol>", the
           port part is invalid, ambiguous or a reversed/empty range, or the
           protocol is neither "tcp" nor "udp"
    """
    try:
        (ports, protocol) = value.split(":")
    except (ValueError, AttributeError):
        # not exactly one ":" separator (or not a string at all)
        raise OptionError(_("invalid port definition %s.") % value, opt)

    ports = ports.strip()
    port_range = getPortRange(ports)
    # As used here, getPortRange() yields None for an ambiguous range, a
    # negative number for an invalid one and a sequence of port ids
    # otherwise.  None has to be tested first: "None < 0" is true on
    # Python 2 (which made the "not unique" message unreachable) and a
    # TypeError on Python 3.
    if port_range is None:
        raise OptionError(_("port range %s is not unique.") % value, opt)
    if isinstance(port_range, int) and port_range < 0:
        raise OptionError(_("invalid port definition %s.") % value, opt)
    if len(port_range) == 2 and port_range[0] >= port_range[1]:
        raise OptionError(_("%s is not a valid range (start port >= end "
                            "port).") % value, opt)

    protocol = protocol.strip()
    if protocol not in ("tcp", "udp"):
        raise OptionError(_("%s is not a valid protocol.") % protocol, opt)
    return (ports, protocol)

def _check_forward_port(option, opt, value):
    """optparse type checker for forward port definitions.

    The value is a ":" separated list of key=value items using the keys
    "if", "port", "proto", "toport" and "toaddr".  "if", "port" and "proto"
    are mandatory and at least one of "toport" / "toaddr" has to be given.

    @param option the Option being processed (unused)
    @param opt the option string as given on the command line
    @param value the raw option argument
    @return dict mapping each supplied key to its (string) value
    @raise OptionError if an item is malformed or fails validation, or if a
           mandatory key is missing
    """
    def _valid_port_range(text):
        # Explicit checks instead of "getPortRange(text) > 0": comparing
        # None or a sequence with 0 only worked through Python 2's arbitrary
        # cross-type ordering and raises TypeError on Python 3.
        port_range = getPortRange(text)
        if port_range is None:
            return False
        if isinstance(port_range, int):
            return port_range > 0
        return True

    result = { }
    error = None
    splits = value.split(":", 1)
    while splits:
        item = splits[0]
        key_val = item.split("=")
        if len(key_val) != 2:
            error = _("Invalid argument %s") % item
            break
        (key, val) = key_val
        if (key == "if" and checkInterface(val)) or \
                (key == "proto" and val in ("tcp", "udp")) or \
                (key == "toaddr" and checkIP(val)) or \
                (key in ("port", "toport") and _valid_port_range(val)):
            result[key] = val
        else:
            error = _("Invalid argument %s") % item
            break

        if len(splits) == 1:
            # nothing left to parse
            break
        rest = splits[1]
        if rest.count("=") == 1:
            # last element: taken as a whole, even if it contains ":"
            splits = [ rest ]
        else:
            splits = rest.split(":", 1)

    if error:
        info = { "option": opt, "value": value, "error": error }
        raise OptionError(_("option %(option)s: invalid forward_port "
                                 "'%(value)s': %(error)s.") % info, opt)

    missing = [ key for key in ("if", "port", "proto") if key not in result ]
    if missing or ("toport" not in result and "toaddr" not in result):
        info = { "option": opt, "value": value }
        raise OptionError(_("option %(option)s: invalid forward_port "
                                 "'%(value)s'.") % info, opt)

    return result

def _check_interface(option, opt, value):
    """optparse type checker for interface names.

    @param option the Option being processed (unused)
    @param opt the option string as given on the command line
    @param value the raw option argument
    @return value unchanged if checkInterface() accepts it
    @raise OptionError if checkInterface() rejects the value
    """
    if checkInterface(value):
        return value
    raise OptionError(_("invalid interface '%s'.") % value, opt)

def _append_unique(option, opt, value, parser, *args, **kwargs):
    """optparse callback: append value to the option's list, once only.

    The list lives in parser.values under option.dest and is created on
    first use.  A value that is already in the list is dropped silently.
    """
    dest = option.dest
    current = getattr(parser.values, dest)
    already_there = current and value in current
    if not already_there:
        parser.values.ensure_value(dest, []).append(value)

class _Option(Option):
    """optparse Option that also accepts the firewall specific value types.

    Checker functions are registered for "port", "forward_port" and
    "interface" only; "rulesfile", "service" and "icmp_type" are declared as
    type names without a checker of their own.
    """
    TYPES = Option.TYPES + (
        "port", "rulesfile", "service", "forward_port", "icmp_type",
        "interface",
    )
    # work on a private copy, the checker table of Option stays untouched
    TYPE_CHECKER = dict(Option.TYPE_CHECKER)
    TYPE_CHECKER.update({
        "port": _check_port,
        "forward_port": _check_forward_port,
        "interface": _check_interface,
    })

def _addStandardOptions(parser):
    """Register the system-config-firewall options on parser.

    --enabled / --disabled toggle the "enabled" flag (default: True).  All
    other options collect their arguments in a list through the
    _append_unique callback, so repeated values are stored only once.

    @param parser option parser whose option class knows the custom types
                  used below ("service", "port", "interface", "rulesfile",
                  "forward_port", "icmp_type")
    """
    def add_unique(*flags, **attrs):
        # shared settings of every list collecting option
        parser.add_option(*flags, action="callback",
                          callback=_append_unique, **attrs)

    parser.add_option("--enabled", dest="enabled", action="store_true",
                      default=True,
                      help=_("Enable firewall (default)"))
    parser.add_option("--disabled", dest="enabled", action="store_false",
                      help=_("Disable firewall"))

    add_unique("--addmodule", dest="add_module", type="string",
               metavar=_("<module>"),
               help=_("Enable an iptables module"))
    add_unique("--removemodule", dest="remove_module", type="string",
               metavar=_("<module>"),
               help=_("Disable an iptables module"))
    add_unique("-s", "--service", dest="services", type="service",
               default=[ ],
               metavar=_("<service>"),
               help=_("Open the firewall for a service (e.g, ssh)"))
    add_unique("-p", "--port", dest="ports", type="port",
               metavar=_("<port>[-<port>]:<protocol>"),
               help=_("Open specific ports in the firewall "
                      "(e.g, ssh:tcp)"))
    add_unique("-t", "--trust", dest="trust", type="interface",
               metavar=_("<interface>"),
               help=_("Allow all traffic on the specified device"))
    add_unique("-m", "--masq", dest="masq", type="interface",
               metavar=_("<interface>"),
               help=_("Masquerades traffic from the specified device. "
                      "This is IPv4 only."))
    add_unique("--custom-rules", dest="custom_rules", type="rulesfile",
               metavar=_("[<type>:][<table>:]<filename>"),
               help=_("Specify a custom rules file for inclusion in "
                      "the firewall, after the "
                      "default rules. Default protocol type: ipv4, "
                      "default table: filter. "
                      "(Example: ipv4:filter:/etc/sysconfig/"
                      "ipv4_filter_addon)"))
    add_unique("--forward-port", dest="forward_port", type="forward_port",
               metavar=_("if=<interface>:port=<port>:proto=<protocol>"
                         "[:toport=<destination port>]"
                         "[:toaddr=<destination address>]"),
               help=_("Forward the port with protocol for the "
                      "interface to either another local destination "
                      "port (no destination address given) or to an "
                      "other destination address with an optional "
                      "destination port. This is IPv4 only."))
    add_unique("--block-icmp", dest="block_icmp", type="icmp_type",
               default=[ ],
               metavar=_("<icmp type>"),
               help=_("Block this ICMP type. The default is to accept "
                      "all ICMP types."))

def _parse_args(parser, args, options=None):
    """Run parser over args and return the resulting option values.

    @param parser the option parser; its _fw_exit attribute is consulted
    @param args argument list handed to parser.parse_args()
    @param options optional values object to merge the results into
    @return the values object, with "filename" (None) and "converted"
            (False) added if missing; None if parse_args() raised
    """
    try:
        (parsed, leftovers) = parser.parse_args(args, options)
    except Exception as error:
        parser.error(error)
        return None

    # positional arguments are not supported, report every one of them
    for arg in leftovers:
        parser.error(_("no such option: %s") % arg)
    if parser._fw_exit:
        sys.exit(2)

    for (attr, fallback) in (("filename", None), ("converted", False)):
        if not hasattr(parsed, attr):
            setattr(parsed, attr, fallback)
    return parsed

class _OptionParser(OptionParser):
    """OptionParser that reports problems instead of terminating.

    exit() only prints its message to stderr and error() prefixes the
    message with the _fw_source attribute (if set), so parsing carries on
    after a bad option.  Long options have to be spelled out in full.
    Instances need a _fw_source attribute before error() is used.
    """

    # overload print_help: rhpl._ returns UTF-8
    def print_help(self, file=None):
        """Write the help text to file (default: sys.stdout)."""
        if file is None:
            file = sys.stdout

        text = self.format_help()
        if not isinstance(text, str):
            # Python 2 unicode help text: encode it for the target stream
            encoding = self._get_encoding(file)
            text = text.encode(encoding, "replace")
        file.write(text)

    def print_usage(self, file=None):
        """Usage output is suppressed on purpose."""
        pass

    def exit(self, status=0, msg=None):
        """Print msg to stderr; status is ignored and nothing exits."""
        if msg:
            text = "%s\n" % (msg,)
            if not isinstance(text, str):
                # Python 2 unicode message: encode it for stderr
                encoding = getattr(sys.stderr, "encoding", None) or \
                    sys.getdefaultencoding()
                text = text.encode(encoding, "replace")
            sys.stderr.write(text)

    def error(self, msg):
        """Report msg via exit(), prefixed with the source if there is one.

        Unlike OptionParser.error() this method returns to its caller.
        """
        if self._fw_source:
            text = "%s: %s" % (self._fw_source, msg)
        else:
            text = str(msg)
        self.exit(2, msg=text)

    def _match_long_opt(self, opt):
        """Return opt if it is a known long option (no abbreviations).

        @raise BadOptionError for every other option string
        """
        if opt in self._long_opt:
            return opt
        raise BadOptionError(opt)

    def _process_long_opt(self, rargs, values):
        # allow to ignore errors in the ui
        try:
            self.__process_long_opt(rargs, values)
        except Exception as msg:
            self.error(msg)

    def _process_short_opts(self, rargs, values):
        # allow to ignore errors in the ui
        try:
            OptionParser._process_short_opts(self, rargs, values)
        except Exception as msg:
            self.error(msg)

    def __process_long_opt(self, rargs, values):
        """Consume one long option (plus its arguments) from rargs."""
        arg = rargs.pop(0)

        # Value explicitly attached to arg?  Pretend it's the next
        # argument.
        if "=" in arg:
            (opt, next_arg) = arg.split("=", 1)
            had_explicit_value = True
        else:
            opt = arg
            had_explicit_value = False

        opt = self._match_long_opt(opt)
        option = self._long_opt[opt]
        if option.takes_value():
            nargs = option.nargs
            if len(rargs)+int(had_explicit_value) < nargs:
                if nargs == 1:
                    self.error(_("%s option requires an argument") % opt)
                else:
                    info = { "option": opt, "count": nargs }
                    self.error(_("%(option)s option requires %(count)s "
                                 "arguments") % info)
                # error() returns here: stop before "value" is used unbound,
                # which would be reported as a second, bogus error
                return
            elif nargs == 1 and had_explicit_value:
                value = next_arg
            elif nargs == 1:
                value = rargs.pop(0)
            elif had_explicit_value:
                value = tuple([ next_arg ] + rargs[0:nargs-1])
                del rargs[0:nargs-1]
            else:
                value = tuple(rargs[0:nargs])
                del rargs[0:nargs]

        elif had_explicit_value:
            self.error(_("%s option does not take a value") % opt)
            # same as above: the option must not be processed after an error
            return

        else:
            value = None

        option.process(opt, value, values, self)

def _gen_parser(source=None):
    """Create an _OptionParser that uses _Option as its option class.

    @param source stored as _fw_source on the parser (used as prefix for
                  error messages)
    @return the new parser, with _fw_exit set to False
    """
    new_parser = _OptionParser(option_class=_Option)
    new_parser._fw_exit = False
    new_parser._fw_source = source
    return new_parser

def parseSysconfigArgs(args, options=None, source=None):
    """Parse args with a parser that knows the standard firewall options.

    @param args argument list to parse
    @param options optional values object to merge the results into
    @param source origin of the arguments, used in error messages
    @return whatever _parse_args() returns for the freshly built parser
    """
    fw_parser = _gen_parser(source)
    _addStandardOptions(fw_parser)
    parsed = _parse_args(fw_parser, args, options)
    return parsed

# system-config-firewall: fw_sysconfig

def read_sysconfig_args(config='/etc/sysconfig/system-config-firewall'):
    """Read the argument lines of a system-config-firewall config file.

    Every line is stripped of surrounding whitespace; empty lines and lines
    starting with '#' are skipped.

    @param config path of the config file (defaults to the system-wide one)
    @return tuple (argv, filename) with the list of argument lines and the
            path they were read from, or None if the path is not a regular
            file or can not be opened
    """
    # isfile() is also False for a missing path
    if not os.path.isfile(config):
        return None
    try:
        fd = open(config, 'r')
    except (IOError, OSError):
        return None

    argv = [ ]
    # the with block closes the file even if reading fails half way
    with fd:
        for line in fd:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            argv.append(line)
    return (argv, config)

def parse_sysconfig_args(args, merge_config=None, filename=None):
    """Parse args and record the file they originate from.

    @param args argument list to parse
    @param merge_config optional values object to merge the results into
    @param filename origin of the arguments; used in error messages and
                    stored as "filename" on the result
    @return the parsed configuration, or None if parsing gave no result
    """
    config = parseSysconfigArgs(args, options=merge_config, source=filename)
    if config:
        config.filename = filename
        return config
    return None

def read_sysconfig_config(merge_config=None):
    """Load and parse the system-config-firewall config file.

    @param merge_config optional values object to merge the results into
    @return the parsed configuration; merge_config itself if
            read_sysconfig_args() delivered nothing
    """
    found = read_sysconfig_args() # (args, filename) or None
    if found:
        (argv, filename) = found
        return parse_sysconfig_args(argv, merge_config, filename)
    return merge_config

# Get the old configuration either from the command line or, without any
# arguments, from the system-config-firewall config file.
if len(sys.argv) > 1:
    # Parse the cmdline args and setup the initial firewall state
    conf = parse_sysconfig_args(None)
    if not conf:
        print ("Error: problem parsing arguments.")
        sys.exit(1)
else:
    # open system-config-firewall config
    conf = read_sysconfig_config()
    if not conf:
        print ("Error: problem reading system-config-firewall config file.")
        sys.exit(1)

# open firewalld config file to get default zone

zone = "public" # default zone in case of missing config file

_firewalld_conf = firewalld_conf(FIREWALLD_CONF)
try:
    _firewalld_conf.read()
except Exception:
    # ignore read error, use default zone
    pass
else:
    zone = _firewalld_conf.get("DefaultZone")

# load the zone file, the first directory that has one wins
obj = None
filename = "%s.xml" % zone
for path in [ ETC_FIREWALLD_ZONES, FIREWALLD_ZONES ]:
    if os.path.exists("%s/%s" %(path, filename)):
        obj = zone_reader(filename, path)
        break

if not obj:
    # create new zone?
    sys.exit(0)

changed = False

# fields that can not get converted into a zone, needs NM work
if conf.enabled == False:
    print("Firewall was disabled, unable to convert to zone.")
if conf.trust:
    for dev in conf.trust:
        print("Device %s was trusted, unable to convert to zone." % dev)

# no custom rules
if conf.custom_rules:
    for custom in conf.custom_rules:
        # There is no checker for the "rulesfile" type in _Option, parsed
        # values therefore are plain strings.  Joining a string would put a
        # ":" between all of its characters, so join real sequences only.
        if isinstance(custom, (tuple, list)):
            custom = ":".join(custom)
        print("Ignoring custom-rule '%s'\n" % custom)

# no modules
if conf.add_module:
    for module in conf.add_module:
        print("Ignoring addmodule '%s'\n" % module)
if conf.remove_module:
    for module in conf.remove_module:
        print("Ignoring removemodule '%s'\n" % module)

# NOTE(review): the message announces that masquerading gets enabled, but
# nothing is changed in the zone here - confirm whether the zone object
# should be updated (and "changed" set) as well.
if conf.masq:
    for dev in conf.masq:
        print("Device %s was masqueraded, enabling masquerade for the default zone." % dev)

if conf.ports:
    for item in conf.ports:
        if item not in obj.ports:
            print("Adding port %s/%s to default zone." % (item[0], item[1]))
            obj.ports.append(item)
            changed = True

if conf.services:
    for service in conf.services:
        if service not in obj.services:
            print("Adding service %s to default zone." % service)
            obj.services.append(service)
            changed = True

if conf.block_icmp:
    for icmp in conf.block_icmp:
        if icmp not in obj.icmp_blocks:
            print("Adding icmpblock %s to default zone." % icmp)
            obj.icmp_blocks.append(icmp)
            changed = True

if conf.forward_port:
    for fwd in conf.forward_port:
        # ignore interface, should belong to default zone
        # Only one of toport / toaddr is mandatory for a forward port, a
        # plain fwd["..."] lookup raised KeyError for the missing one.
        # NOTE(review): "" is assumed to be what the zone uses for an unset
        # destination port / address - confirm against the zone reader.
        entry = (fwd["port"], fwd["proto"],
                 fwd.get("toport", ""), fwd.get("toaddr", ""))
        if entry not in obj.forward_ports:
            print("Adding forward port %s:%s:%s:%s to default zone." % entry)
            obj.forward_ports.append(entry)
            changed = True

if changed:
    zone_writer(obj, ETC_FIREWALLD_ZONES)
else:
    print("No changes to default zone needed.")

sys.exit(0)
