/* Connection tracking via netlink socket. Allows for user space
 * protocol helpers and general trouble making from userspace.
 *
 * (C) 2001 by Jay Schulist <jschlst@samba.org>
 * (C) 2002-2005 by Harald Welte <laforge@gnumonks.org>
 * (C) 2003 by Patrick Mchardy <kaber@trash.net>,
 *
 * Initial connection tracking via netlink development funded and 
 * generally made possible by Network Robots, Inc. (www.networkrobots.com)
 *
 * Further development of this code funded by Astaro AG (http://www.astaro.com)
 *
 * This software may be used and distributed according to the terms
 * of the GNU General Public License, incorporated herein by reference.
 */

#include <linux/config.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/socket.h>
#include <linux/kernel.h>
#include <linux/major.h>
#include <linux/sched.h>
#include <linux/timer.h>
#include <linux/string.h>
#include <linux/sockios.h>
#include <linux/net.h>
#include <linux/fcntl.h>
#include <linux/skbuff.h>
#include <linux/errno.h>
#include <asm/uaccess.h>
#include <asm/system.h>
#include <net/sock.h>
#include <linux/init.h>
#include <linux/netlink.h>
#include <linux/spinlock.h>
#include <linux/notifier.h>
#include <linux/rtnetlink.h>

#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/netfilter_ipv4/ip_tables.h>
#include <linux/netfilter_ipv4/ip_conntrack.h>
#include <linux/netfilter_ipv4/ip_conntrack_core.h>
#include <linux/netfilter_ipv4/ip_conntrack_helper.h>
#include <linux/netfilter_ipv4/ip_conntrack_protocol.h>

#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter_ipv4/ip_conntrack_netlink.h>

#define ASSERT_READ_LOCK(x) MUST_BE_READ_LOCKED(&ip_conntrack_lock)
#define ASSERT_WRITE_LOCK(x) MUST_BE_WRITE_LOCKED(&ip_conntrack_lock)
#include <linux/netfilter_ipv4/listhelp.h>

MODULE_LICENSE("GPL");

/* Version string printed by ctnetlink_init(); placed in init-only data. */
static char __initdata ctversion[] = "0.60";

/* Debug tracing is compiled out: with the "#if 0" branch disabled, DEBUGP()
 * expands to nothing.  Flip the condition to route it to printk. */
#if 0
#define DEBUGP printk
#else
#define DEBUGP(format, args...)
#endif

/* Bitmask of conntrack events that ctnetlink_conntrack_event() reacts to;
 * all bits set by default, replaced via ctnetlink_config(). */
static unsigned int event_mask = ~0UL;
/* Bitmask of attribute groups emitted by ctnetlink_fill_info(); all bits
 * set by default, replaced via ctnetlink_config(). */
static unsigned int dump_mask = ~0UL;
/* nfnetlink subsystem descriptor; allocated in ctnetlink_init() and freed
 * in ctnetlink_exit(). */
static struct nfnetlink_subsystem *ctnl_subsys;

static inline int
ctnetlink_dump_tuples(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	NFA_PUT(skb, CTA_ORIG, sizeof(struct ip_conntrack_tuple),
	        &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
	NFA_PUT(skb, CTA_RPLY, sizeof(struct ip_conntrack_tuple),
	        &ct->tuplehash[IP_CT_DIR_REPLY].tuple);
	return 0;

nfattr_failure:
	return -1;
}

/*
 * Append the status word of @ct to @skb as a CTA_STATUS attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_dump_status(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	int ret = -1;

	NFA_PUT(skb, CTA_STATUS, sizeof(ct->status), &ct->status);
	ret = 0;

nfattr_failure:
	return ret;
}

/*
 * Append the remaining lifetime of @ct, in seconds, to @skb as a
 * CTA_TIMEOUT attribute.
 *
 * The difference between the timer expiry and the current jiffies value is
 * evaluated as a signed quantity: if the timer has already expired, 0 is
 * reported instead of the huge value an unsigned subtraction would wrap to.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_dump_timeout(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	long remaining = (long)(ct->timeout.expires - jiffies);
	unsigned long timeout = 0;

	if (remaining > 0)
		timeout = remaining / HZ;

	NFA_PUT(skb, CTA_TIMEOUT, sizeof(timeout), &timeout);
	return 0;

nfattr_failure:
	return -1;
}

/*
 * Append the protocol number and per-protocol state of @ct to @skb as a
 * CTA_PROTOINFO attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_dump_protoinfo(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	struct cta_proto cp;

	/* The whole struct is copied into the message, so clear it first:
	 * any padding between or after the members must not carry stale
	 * kernel stack contents to the receiver. */
	memset(&cp, 0, sizeof(cp));

	cp.num_proto = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum;
	memcpy(&cp.proto, &ct->proto, sizeof(cp.proto));
	NFA_PUT(skb, CTA_PROTOINFO, sizeof(cp), &cp);
	return 0;

nfattr_failure:
	return -1;
}

/*
 * Append the helper name and helper state of @ct to @skb as a CTA_HELPINFO
 * attribute.  A conntrack without a helper is reported as an all-zero
 * structure.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_dump_helpinfo(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	struct ip_conntrack_helper *h = ct->helper;
	struct cta_help ch;

	/* Always start from a zeroed struct: the whole thing is copied into
	 * the message, so padding must not carry stale stack contents. */
	memset(&ch, 0, sizeof(ch));

	if (h != NULL) {
		/* Copy at most size-1 bytes so the zeroed last byte keeps the
		 * name NUL-terminated even for an over-long helper name. */
		strncpy((char *)&ch.name, h->name, sizeof(ch.name) - 1);
		memcpy(&ch.help, &ct->help, sizeof(ch.help));
	}
	NFA_PUT(skb, CTA_HELPINFO, sizeof(ch), &ch);
	return 0;

nfattr_failure:
	return -1;
}

#ifdef CONFIG_IP_NF_CT_ACCT
/*
 * Append the per-direction accounting counters of @ct to @skb as a
 * CTA_COUNTERS attribute.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_dump_counters(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	struct cta_counters cc;

	memcpy(&cc.orig, &ct->counters[IP_CT_DIR_ORIGINAL], sizeof(cc.orig));
	memcpy(&cc.reply, &ct->counters[IP_CT_DIR_REPLY], sizeof(cc.reply));

	NFA_PUT(skb, CTA_COUNTERS, sizeof(cc), &cc);
	return 0;

nfattr_failure:
	return -1;
}
#else
/* Without accounting support, dumping counters is a successful no-op. */
#define ctnetlink_dump_counters(a, b) (0)
#endif

#ifdef CONFIG_IP_NF_CONNTRACK_MARK
/*
 * Append the mark of @ct to @skb as a CTA_MARK attribute.
 *
 * NOTE(review): the attribute length is sizeof(unsigned long), not
 * sizeof(ct->mark) - confirm the two agree on every architecture.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_dump_mark(struct sk_buff *skb, const struct ip_conntrack *ct)
{
	int ret = -1;

	NFA_PUT(skb, CTA_MARK, sizeof(unsigned long), &ct->mark);
	ret = 0;

nfattr_failure:
	return ret;
}
#else
/* Without conntrack mark support, dumping the mark is a successful no-op. */
#define ctnetlink_dump_mark(a, b) (0)
#endif

/*
 * Build one conntrack netlink message for @ct at the tail of @skb.
 *
 * @pid, @seq: values placed in the netlink header
 * @event:     IPCTNL_MSG_* type; the ctnetlink subsystem id is OR-ed in
 * @nowait:    when non-zero together with a non-zero @pid, the message is
 *             flagged NLM_F_MULTI
 *
 * Only the attribute groups selected in dump_mask are emitted.
 *
 * Returns skb->len on success.  On failure the partially written message is
 * trimmed off again and -1 is returned.
 */
static int
ctnetlink_fill_info(struct sk_buff *skb, u32 pid, u32 seq,
		    int event, int nowait,
		    const struct ip_conntrack *ct)
{
	unsigned char *start = skb->tail;
	struct nfgenmsg *nfmsg;
	struct nlmsghdr *nlh;

	event |= NFNL_SUBSYS_CTNETLINK << 8;
	nlh    = NLMSG_PUT(skb, pid, seq, event, sizeof(struct nfgenmsg));
	nfmsg  = NLMSG_DATA(nlh);

	if (nowait && pid)
		nlh->nlmsg_flags = NLM_F_MULTI;
	else
		nlh->nlmsg_flags = 0;
	nfmsg->nfgen_family = AF_INET;

	/* Emit each attribute group only if it is selected in dump_mask;
	 * the first failing dumper aborts the whole message. */
	if ((dump_mask & DUMP_TUPLE) && ctnetlink_dump_tuples(skb, ct) < 0)
		goto nfattr_failure;
	if ((dump_mask & DUMP_STATUS) && ctnetlink_dump_status(skb, ct) < 0)
		goto nfattr_failure;
	if ((dump_mask & DUMP_TIMEOUT) && ctnetlink_dump_timeout(skb, ct) < 0)
		goto nfattr_failure;
	if ((dump_mask & DUMP_COUNTERS) && ctnetlink_dump_counters(skb, ct) < 0)
		goto nfattr_failure;
	if ((dump_mask & DUMP_PROTOINFO) && ctnetlink_dump_protoinfo(skb, ct) < 0)
		goto nfattr_failure;
	if ((dump_mask & DUMP_HELPINFO) && ctnetlink_dump_helpinfo(skb, ct) < 0)
		goto nfattr_failure;
	if ((dump_mask & DUMP_MARK) && ctnetlink_dump_mark(skb, ct) < 0)
		goto nfattr_failure;

	nlh->nlmsg_len = skb->tail - start;
	return skb->len;

nlmsg_failure:
nfattr_failure:
	/* drop whatever was written for this message */
	skb_trim(skb, start - skb->data);
	return -1;
}

/*
 * Pick the multicast group for events about @ct, based on the protocol
 * number of its original-direction tuple: TCP, UDP and ICMP have their own
 * groups, everything else goes to NFGRP_IPV4_CT_OTHER.
 */
static inline unsigned int
ctnetlink_get_mcgroups(struct ip_conntrack *ct)
{
	switch (ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum) {
	case IPPROTO_TCP:
		return NFGRP_IPV4_CT_TCP;
	case IPPROTO_UDP:
		return NFGRP_IPV4_CT_UDP;
	case IPPROTO_ICMP:
		return NFGRP_IPV4_CT_ICMP;
	default:
		return NFGRP_IPV4_CT_OTHER;
	}
}

/*
 * Conntrack event notifier: turn a conntrack event into a netlink message
 * and hand it to nfnetlink_send() for the multicast group matching the
 * connection's protocol.
 *
 * @this:   notifier block (unused)
 * @events: IPCT_* event bits for this notification
 * @ptr:    the struct ip_conntrack the events refer to
 *
 * Events are filtered through event_mask.  DESTROY yields a
 * IPCTNL_MSG_CT_DELETE message; NEW/RELATED yields IPCTNL_MSG_CT_NEW with
 * NLM_F_CREATE|NLM_F_EXCL and every attribute group enabled by event_mask;
 * the remaining update events yield IPCTNL_MSG_CT_NEW carrying only the
 * attributes matching the event bits.  Anything else is ignored.
 *
 * The message type is determined before the skb is allocated, so events
 * that are filtered out do not cost an atomic allocation and free.
 *
 * Always returns NOTIFY_DONE; failures are silently dropped.
 */
static int ctnetlink_conntrack_event(struct notifier_block *this,
                                     unsigned long events, void *ptr)
{
	struct nlmsghdr *nlh;
	struct nfgenmsg *nfmsg;
	struct ip_conntrack *ct = (struct ip_conntrack *)ptr;
	struct sk_buff *skb;
	unsigned int type;
	unsigned char *b;
	int flags = 0;

	/* ignore our fake conntrack entry */
	if (ct == &ip_conntrack_untracked)
		return NOTIFY_DONE;

	if (event_mask & events & IPCT_DESTROY)
		type = IPCTNL_MSG_CT_DELETE;
	else if (event_mask & events & (IPCT_NEW | IPCT_RELATED)) {
		type = IPCTNL_MSG_CT_NEW;
		flags = NLM_F_CREATE|NLM_F_EXCL;
		/* dump everything */
		events = ~0UL;
	} else if (event_mask & events &
			     (IPCT_STATUS |
			     IPCT_PROTOINFO |
			     IPCT_HELPER |
			     IPCT_HELPINFO |
			     IPCT_NATINFO)) {
		type = IPCTNL_MSG_CT_NEW;
	} else {
		/* nothing the listeners asked for */
		return NOTIFY_DONE;
	}

	/* netlink_trim now reduces the impact of using NLMSG_GOODSIZE */
	skb = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC);
	if (!skb)
		return NOTIFY_DONE;

	b = skb->tail;

	type |= NFNL_SUBSYS_CTNETLINK << 8;
	nlh   = NLMSG_PUT(skb, 0, 0, type, sizeof(struct nfgenmsg));
	nfmsg = NLMSG_DATA(nlh);

	nlh->nlmsg_flags    = flags;
	nfmsg->nfgen_family = AF_INET;

	/* the tuples identify the connection and are always included */
	if (ctnetlink_dump_tuples(skb, ct) < 0)
		goto nfattr_failure;

	/* NAT stuff is now a status flag */
	if (((event_mask & events & IPCT_STATUS)
	    || (event_mask & events & IPCT_NATINFO))
	    && ctnetlink_dump_status(skb, ct) < 0)
		goto nfattr_failure;
	if (event_mask & events & IPCT_REFRESH
	    && ctnetlink_dump_timeout(skb, ct) < 0)
		goto nfattr_failure;
	if (event_mask & events & IPCT_PROTOINFO
	    && ctnetlink_dump_protoinfo(skb, ct) < 0)
		goto nfattr_failure;
	if (event_mask & events & IPCT_HELPINFO
	    && ctnetlink_dump_helpinfo(skb, ct) < 0)
		goto nfattr_failure;

#ifdef CONFIG_IP_NF_CT_ACCT
	if (ctnetlink_dump_counters(skb, ct) < 0)
		goto nfattr_failure;
#endif

	nlh->nlmsg_len = skb->tail - b;
	nfnetlink_send(skb, 0, ctnetlink_get_mcgroups(ct), 0);
	return NOTIFY_DONE;

nlmsg_failure:
nfattr_failure:
	kfree_skb(skb);
	return NOTIFY_DONE;
}

/*
 * Minimum payload size, in bytes, accepted for each attribute type.
 * The table is indexed by attribute type minus one, matching the layout of
 * the cda[] arrays filled in by nfnetlink_check_attributes().  Request
 * handlers compare NFA_PAYLOAD() against these values before touching the
 * attribute data.
 */
static const int cta_min[CTA_MAX] = {
	/* conntrack attributes */
	[CTA_ORIG-1]		= sizeof(struct ip_conntrack_tuple),
	[CTA_RPLY-1]		= sizeof(struct ip_conntrack_tuple),
	[CTA_STATUS-1]		= sizeof(unsigned long),
	[CTA_PROTOINFO-1]	= sizeof(struct cta_proto),
	[CTA_HELPINFO-1]	= sizeof(struct cta_help),
	[CTA_TIMEOUT-1]		= sizeof(unsigned long),
	[CTA_MARK-1]		= sizeof(unsigned long),
	[CTA_COUNTERS-1]	= sizeof(struct cta_counters),
	[CTA_DUMPMASK-1]	= sizeof(unsigned int),
	[CTA_EVENTMASK-1]	= sizeof(unsigned int),

	/* expectation attributes */
	[CTA_EXP_TUPLE-1]	= sizeof(struct ip_conntrack_tuple),
	[CTA_EXP_MASK-1]	= sizeof(struct ip_conntrack_tuple),
	[CTA_EXP_SEQNO-1]	= sizeof(u_int32_t),
	[CTA_EXP_PROTO-1]	= sizeof(struct cta_exp_proto),
	[CTA_EXP_TIMEOUT-1]	= sizeof(unsigned long)
};

/*
 * IPCTNL_MSG_CONFIG handler: update dump_mask and/or event_mask from the
 * CTA_DUMPMASK / CTA_EVENTMASK attributes of the request.
 *
 * The dump mask is processed first, so a valid CTA_DUMPMASK is already
 * applied when a following CTA_EVENTMASK is rejected.
 *
 * Returns 0 on success, -EINVAL for malformed attributes, an undersized
 * mask attribute or a mask value of zero.
 */
static int
ctnetlink_config(struct sock *ctnl, struct sk_buff *skb,
		 struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct nfattr *attr;
	unsigned int *mask;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	attr = cda[CTA_DUMPMASK-1];
	if (attr) {
		if (NFA_PAYLOAD(attr) < cta_min[CTA_DUMPMASK-1])
			return -EINVAL;

		mask = NFA_DATA(attr);
		/* an empty mask is not accepted */
		if (*mask == 0)
			return -EINVAL;
		dump_mask = *mask;
		DEBUGP("new dump mask set:%u\n", dump_mask);
	}

	attr = cda[CTA_EVENTMASK-1];
	if (attr) {
		if (NFA_PAYLOAD(attr) < cta_min[CTA_EVENTMASK-1])
			return -EINVAL;

		mask = NFA_DATA(attr);
		/* an empty mask is not accepted */
		if (*mask == 0)
			return -EINVAL;
		event_mask = *mask;
		DEBUGP("new event mask set:%u\n", event_mask);
	}

	DEBUGP("leaving\n");

	return 0;
}

/*
 * Match callback: @data points to a conntrack; returns 1 when the
 * original-direction tuple hash entry of @i is byte-for-byte identical to
 * that of @data (the full struct ip_conntrack_tuple_hash is compared),
 * otherwise 0.
 */
static inline int ctnetlink_kill(struct ip_conntrack *i, void *data)
{
	struct ip_conntrack *wanted = (struct ip_conntrack *)data;

	return memcmp(&i->tuplehash[IP_CT_DIR_ORIGINAL],
	              &wanted->tuplehash[IP_CT_DIR_ORIGINAL],
	              sizeof(struct ip_conntrack_tuple_hash)) == 0;
}

/*
 * IPCTNL_MSG_CT_DELETE handler: look up the conntrack named by CTA_ORIG
 * (preferred) or CTA_RPLY and kill it by running its timeout handler.
 *
 * The reference obtained by ip_conntrack_find_get() is held across the call
 * to the timeout handler and is dropped on every path afterwards, including
 * the case where del_timer() reports that the timer was not pending (someone
 * else is already tearing the entry down).
 *
 * Returns 0 on success, -EINVAL for malformed or missing tuple attributes,
 * -ENOENT if no matching conntrack exists.
 */
static int
ctnetlink_del_conntrack(struct sock *ctnl, struct sk_buff *skb,
			struct nlmsghdr *nlh, int *errp)
{
	struct ip_conntrack_tuple_hash *h;
	struct ip_conntrack_tuple *tuple;
	struct nfattr *cda[CTA_MAX];
	struct ip_conntrack *ct;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	if (cda[CTA_ORIG-1] &&
	    NFA_PAYLOAD(cda[CTA_ORIG-1]) < cta_min[CTA_ORIG-1])
		return -EINVAL;

	if (cda[CTA_RPLY-1] &&
	    NFA_PAYLOAD(cda[CTA_RPLY-1]) < cta_min[CTA_RPLY-1])
		return -EINVAL;

	if (cda[CTA_ORIG-1])
		tuple = NFA_DATA(cda[CTA_ORIG-1]);
	else if (cda[CTA_RPLY-1])
		tuple = NFA_DATA(cda[CTA_RPLY-1]);
	else {
		DEBUGP("no tuple found in request\n");
		return -EINVAL;
	}

	h = ip_conntrack_find_get(tuple, NULL);
	if (!h) {
		DEBUGP("tuple not found in conntrack hash:");
		DUMP_TUPLE(tuple);
		return -ENOENT;
	}

	ct = tuplehash_to_ctrack(h);
	/* Only the context that successfully deactivates the timer may run
	 * the timeout handler. */
	if (del_timer(&ct->timeout))
		ct->timeout.function((unsigned long)ct);

	/* drop the reference taken by ip_conntrack_find_get() */
	ip_conntrack_put(ct);
	DEBUGP("leaving\n");

	return 0;
}

/*
 * "done" callback handed to netlink_dump_start() by the dump requests in
 * this file.  It only emits a debug trace and returns 0.
 */
static int ctnetlink_done(struct netlink_callback *cb)
{
	DEBUGP("entered %s\n", __FUNCTION__);
	return 0;
}

static int
ctnetlink_dump_table(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct ip_conntrack *ct = NULL;
	struct ip_conntrack_tuple_hash *h;
	struct list_head *i;

	DEBUGP("entered %s, last bucket=%lu id=%lu\n", __FUNCTION__, 
			cb->args[0], cb->args[1]);

	READ_LOCK(&ip_conntrack_lock);
	for (; cb->args[0] < ip_conntrack_htable_size; 
					cb->args[0]++, cb->args[1]=0) {
		list_for_each(i, &ip_conntrack_hash[cb->args[0]]) {
			h = (struct ip_conntrack_tuple_hash *) i;
			if (DIRECTION(h) != IP_CT_DIR_ORIGINAL)
				continue;
			ct = tuplehash_to_ctrack(h);
			if (ct->id <= cb->args[1])
				continue;
			if (ctnetlink_fill_info(skb, NETLINK_CB(cb->skb).pid,
		                        	cb->nlh->nlmsg_seq,
						IPCTNL_MSG_CT_NEW,
						1, ct) < 0)
				goto out;
			cb->args[1] = ct->id;
		}
	}
out:	READ_UNLOCK(&ip_conntrack_lock);

	DEBUGP("leaving, last bucket=%lu id=%lu\n", cb->args[0], cb->args[1]);

	return skb->len;
}

#ifdef CONFIG_IP_NF_CT_ACCT
/*
 * Dump callback used for IPCTNL_MSG_CT_GET_CTRZERO: identical walk to
 * ctnetlink_dump_table(), but performed under the write lock because the
 * accounting counters of every successfully dumped connection are zeroed
 * afterwards.
 *
 * Resume state is kept in @cb: args[0] is the current bucket, args[1] the
 * id of the last entry dumped from that bucket.
 *
 * Returns skb->len.
 */
static int
ctnetlink_dump_table_w(struct sk_buff *skb, struct netlink_callback *cb)
{
	const u32 pid = NETLINK_CB(cb->skb).pid;
	const u32 seq = cb->nlh->nlmsg_seq;
	struct ip_conntrack_tuple_hash *hash;
	struct ip_conntrack *ct;
	struct list_head *pos;

	DEBUGP("entered %s, last bucket=%lu id=%lu\n", __FUNCTION__, 
			cb->args[0], cb->args[1]);

	WRITE_LOCK(&ip_conntrack_lock);
	while (cb->args[0] < ip_conntrack_htable_size) {
		list_for_each(pos, &ip_conntrack_hash[cb->args[0]]) {
			hash = (struct ip_conntrack_tuple_hash *) pos;
			/* report each connection only once */
			if (DIRECTION(hash) != IP_CT_DIR_ORIGINAL)
				continue;

			ct = tuplehash_to_ctrack(hash);
			/* already covered by an earlier pass */
			if (ct->id <= cb->args[1])
				continue;

			if (ctnetlink_fill_info(skb, pid, seq,
						IPCTNL_MSG_CT_NEW, 1, ct) < 0)
				goto out;
			cb->args[1] = ct->id;

			/* counters were reported; start counting from zero */
			memset(&ct->counters, 0, sizeof(ct->counters));
		}

		/* bucket finished: advance and reset the per-bucket id */
		cb->args[0]++;
		cb->args[1] = 0;
	}
out:
	WRITE_UNLOCK(&ip_conntrack_lock);

	DEBUGP("leaving, last bucket=%lu id=%lu\n", cb->args[0], cb->args[1]);

	return skb->len;
}
#endif

/*
 * IPCTNL_MSG_CT_GET / IPCTNL_MSG_CT_GET_CTRZERO handler.
 *
 * With NLM_F_DUMP set, a table dump is started (the counter-zeroing variant
 * for GET_CTRZERO, available only with CONFIG_IP_NF_CT_ACCT) and the request
 * is pulled off @skb.  The result of netlink_dump_start() is stored in
 * *@errp.
 *
 * Otherwise the conntrack named by CTA_ORIG (preferred) or CTA_RPLY is
 * looked up and sent back to the requester as a single message followed by
 * an NLMSG_DONE header.
 *
 * Returns 0 on success; -EAFNOSUPPORT, -ENOTSUPP, -EINVAL, -ENOENT, -ENOMEM
 * or the netlink_unicast() error on failure; -1 if the reply could not be
 * built.
 */
static int
ctnetlink_get_conntrack(struct sock *ctnl, struct sk_buff *skb, 
			struct nlmsghdr *nlh, int *errp)
{
	struct ip_conntrack_tuple_hash *h;
	struct ip_conntrack_tuple *tuple;
	struct nfattr *cda[CTA_MAX];
	struct ip_conntrack *ct;
	struct sk_buff *skb2 = NULL;
	int err;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (nlh->nlmsg_flags & NLM_F_DUMP) {
		struct nfgenmsg *msg = NLMSG_DATA(nlh);
		u32 rlen;

		if (msg->nfgen_family != AF_INET)
			return -EAFNOSUPPORT;

		if (NFNL_MSG_TYPE(nlh->nlmsg_type) ==
					IPCTNL_MSG_CT_GET_CTRZERO) {
#ifdef CONFIG_IP_NF_CT_ACCT
			*errp = netlink_dump_start(ctnl, skb, nlh,
						   ctnetlink_dump_table_w,
						   ctnetlink_done);
			if (*errp != 0)
				return -EINVAL;
#else
			return -ENOTSUPP;
#endif
		} else {
			*errp = netlink_dump_start(ctnl, skb, nlh,
						   ctnetlink_dump_table,
						   ctnetlink_done);
			if (*errp != 0)
				return -EINVAL;
		}

		/* consume the request, but never more than the skb holds */
		rlen = NLMSG_ALIGN(nlh->nlmsg_len);
		if (rlen > skb->len)
			rlen = skb->len;
		skb_pull(skb, rlen);
		return 0;
	}

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	if (cda[CTA_ORIG-1] &&
	    NFA_PAYLOAD(cda[CTA_ORIG-1]) < cta_min[CTA_ORIG-1])
		return -EINVAL;

	if (cda[CTA_RPLY-1] &&
	    NFA_PAYLOAD(cda[CTA_RPLY-1]) < cta_min[CTA_RPLY-1])
		return -EINVAL;

	/* the original tuple wins when both are supplied */
	if (cda[CTA_ORIG-1])
		tuple = NFA_DATA(cda[CTA_ORIG-1]);
	else if (cda[CTA_RPLY-1])
		tuple = NFA_DATA(cda[CTA_RPLY-1]);
	else
		return -EINVAL;

	h = ip_conntrack_find_get(tuple, NULL);
	if (!h) {
		DEBUGP("tuple not found in conntrack hash:");
		DUMP_TUPLE(tuple);
		return -ENOENT;
	}
	DEBUGP("tuple found\n");
	ct = tuplehash_to_ctrack(h);

	skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC);
	if (!skb2) {
		ip_conntrack_put(ct);
		return -ENOMEM;
	}
	NETLINK_CB(skb2).dst_pid = NETLINK_CB(skb).pid;

	err = ctnetlink_fill_info(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, 
				  IPCTNL_MSG_CT_NEW, 1, ct);
	/* the conntrack is no longer needed once the message is built */
	ip_conntrack_put(ct);
	if (err <= 0)
		goto nlmsg_failure;

	/* RFC: This simplifies user space handling --pablo */
	NLMSG_PUT(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq, NLMSG_DONE, 0);

	err = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
	if (err < 0)
		return err;

	DEBUGP("leaving\n");
	return 0;

nlmsg_failure:
	if (skb2)
		kfree_skb(skb2);
	return -1;
}

/*
 * IPCTNL_MSG_CT_FLUSH handler: calls ip_conntrack_flush() and returns 0.
 * None of the request parameters are examined.
 */
static int
ctnetlink_flush_conntrack(struct sock *ctnl, struct sk_buff *skb, 
			struct nlmsghdr *nlh, int *errp)
{
	DEBUGP("entered %s\n", __FUNCTION__);

	/* fill the bucket, flush the toilet */
	ip_conntrack_flush();
	
	DEBUGP("leaving\n");
	return 0;
}

/* TODO: Now we have to handle NAT modifications here */
/*
 * Replace the status word of @ct with *@status, provided the change is
 * permitted:
 *   - IPS_EXPECTED, IPS_CONFIRMED and IPS_DESTROYED must not differ from
 *     the current value;
 *   - IPS_SEEN_REPLY and IPS_ASSURED may be turned on but not off.
 *
 * Returns 0 on success, -EINVAL if the requested change is not allowed.
 */
static inline int
ctnetlink_change_status(struct ip_conntrack *ct, unsigned long *status)
{
	const unsigned long wanted  = *status;
	const unsigned long changed = ct->status ^ wanted;

	/* unchangeable */
	if (changed & (IPS_EXPECTED|IPS_CONFIRMED|IPS_DESTROYED))
		return -EINVAL;

	/* SEEN_REPLY bit can only be set */
	if ((changed & IPS_SEEN_REPLY) && !(wanted & IPS_SEEN_REPLY))
		return -EINVAL;

	/* ASSURED bit can only be set */
	if ((changed & IPS_ASSURED) && !(wanted & IPS_ASSURED))
		return -EINVAL;

	ct->status = wanted;
	return 0;
}

/*
 * Apply user-supplied protocol state @cp to @ct.
 *
 * The protocol number in @cp must equal that of the connection's
 * original-direction tuple.  The protocol's optional change_check_proto
 * hook may veto the data; its optional change_proto hook applies it.
 *
 * Returns 0 on success, -EINVAL on protocol mismatch or veto.
 */
static inline int
ctnetlink_change_protoinfo(struct ip_conntrack *ct, struct cta_proto *cp)
{
	struct ip_conntrack_protocol *icp;

	if (cp->num_proto !=
	    ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum)
		return -EINVAL;

	icp = ip_ct_find_proto(cp->num_proto);

	if (icp->change_check_proto) {
		if (icp->change_check_proto(&cp->proto) < 0)
			return -EINVAL;
	}

	if (icp->change_proto)
		icp->change_proto(ct, &cp->proto);

	return 0;
}

/*
 * Apply user-supplied helper information @h to @ct.
 *
 *  - No helper attached, empty name: nothing to do.
 *  - No helper attached, name given: refused for connections that have a
 *    master; otherwise the helper is looked up by the reply tuple.
 *  - Helper attached, empty name: the connection's expectations are removed
 *    and the helper is detached.
 *  - Otherwise the (forcibly NUL-terminated) name must match the helper's
 *    name; the helper is then (re)attached and its optional change_help
 *    hook is invoked with the supplied helper state.
 *
 * Returns 0 on success, -EINVAL or -ENOENT on failure.
 */
static inline int
ctnetlink_change_helpinfo(struct ip_conntrack *ct, struct cta_help *h)
{
	struct ip_conntrack_helper *helper = ct->helper;
	const int name_empty = (*h->name == '\0');

	if (helper == NULL) {
		if (name_empty)
			return 0;
		if (ct->master)
			return -EINVAL;

		helper = ip_ct_find_helper(&ct->tuplehash[IP_CT_DIR_REPLY].tuple);
		if (helper == NULL)
			return -ENOENT;
	} else if (name_empty) {
		ip_ct_remove_expectations(ct);
		ct->helper = NULL;
		return 0;
	}

	/* never trust user space to have terminated the name */
	h->name[CTA_HELP_MAXNAMESZ - 1] = '\0';
	if (strcmp(helper->name, h->name) != 0)
		return -EINVAL;

	ct->helper = helper;
	if (helper->change_help)
		helper->change_help(ct, &h->help);

	return 0;
}

/*
 * Re-arm the timer of @ct so that it expires *@timeout seconds from now.
 *
 * Returns 0 on success, -ETIME if del_timer() reports that the timer was
 * not pending (nothing is re-armed in that case).
 */
static inline int
ctnetlink_change_timeout(struct ip_conntrack *ct, unsigned long *timeout)
{
	if (del_timer(&ct->timeout) == 0)
		return -ETIME;

	ct->timeout.expires = jiffies + *timeout * HZ;
	add_timer(&ct->timeout);

	return 0;
}

/*
 * Update an existing conntrack @ct from the attributes in @cda.
 *
 * Status, protocol info, helper info and timeout are applied in that order,
 * each only if the corresponding attribute is present.  Processing stops at
 * the first failing step; changes made by earlier steps are not rolled back.
 *
 * Returns 0 on success or the negative error of the failing step.
 */
static int
ctnetlink_change_conntrack(struct ip_conntrack *ct, struct nfattr *cda[])
{
	int err;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (cda[CTA_STATUS-1]) {
		err = ctnetlink_change_status(ct,
					      NFA_DATA(cda[CTA_STATUS-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_PROTOINFO-1]) {
		err = ctnetlink_change_protoinfo(ct,
						 NFA_DATA(cda[CTA_PROTOINFO-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_HELPINFO-1]) {
		err = ctnetlink_change_helpinfo(ct,
						NFA_DATA(cda[CTA_HELPINFO-1]));
		if (err < 0)
			return err;
	}

	if (cda[CTA_TIMEOUT-1]) {
		err = ctnetlink_change_timeout(ct,
					       NFA_DATA(cda[CTA_TIMEOUT-1]));
		if (err < 0)
			return err;
	}

	DEBUGP("all done\n");
	return 0;
}

/*
 * Create a new conntrack entry from the attributes in @cda and insert it
 * into the hash table.
 *
 * CTA_ORIG, CTA_RPLY, CTA_STATUS, CTA_PROTOINFO and CTA_TIMEOUT are
 * mandatory, and the status must have IPS_CONFIRMED set.  The protocol's
 * optional change_check_tuples / change_check_proto hooks may veto the
 * request.  CTA_HELPINFO, if present, is applied through
 * ctnetlink_change_helpinfo().
 *
 * Uses __ip_conntrack_hash_insert(); the only caller in this file holds
 * the write side of ip_conntrack_lock.
 *
 * Returns 0 on success; -EINVAL, -ENOMEM or the helper-info error on
 * failure.
 */
static int
ctnetlink_create_conntrack(struct nfattr *cda[])
{
	struct ip_conntrack_tuple *otuple, *rtuple;
	struct ip_conntrack_protocol *icp;
	struct ip_conntrack *ct;
	struct cta_proto *proto;
	unsigned long *status, *timeout;
	int err;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (!cda[CTA_ORIG-1] || !cda[CTA_RPLY-1] || !cda[CTA_STATUS-1] ||
	    !cda[CTA_PROTOINFO-1] || !cda[CTA_TIMEOUT-1]) {
		DEBUGP("required attribute(s) missing\n");
		return -EINVAL;
	}

	otuple  = NFA_DATA(cda[CTA_ORIG-1]);
	rtuple  = NFA_DATA(cda[CTA_RPLY-1]);
	status  = NFA_DATA(cda[CTA_STATUS-1]);
	proto   = NFA_DATA(cda[CTA_PROTOINFO-1]);
	timeout = NFA_DATA(cda[CTA_TIMEOUT-1]);

	/* cannot create unconfirmed connections */
	if ((*status & IPS_CONFIRMED) == 0)
		return -EINVAL;

	icp = ip_ct_find_proto(proto->num_proto);

	if (icp->change_check_tuples) {
		if (icp->change_check_tuples(otuple, rtuple) < 0)
			return -EINVAL;
	}

	if (icp->change_check_proto) {
		if (icp->change_check_proto(&proto->proto) < 0)
			return -EINVAL;
	}

	ct = ip_conntrack_alloc(otuple, rtuple);
	if (ct == NULL)
		return -ENOMEM;

	ct->status = *status;
	ct->timeout.expires = jiffies + *timeout * HZ;

	if (icp->change_proto)
		icp->change_proto(ct, &proto->proto);

	if (cda[CTA_HELPINFO-1]) {
		err = ctnetlink_change_helpinfo(ct,
						NFA_DATA(cda[CTA_HELPINFO-1]));
		if (err < 0) {
			/* not yet visible to anyone else: just free it */
			ip_conntrack_free(ct);
			return err;
		}
	}

	__ip_conntrack_hash_insert(ct);
	add_timer(&ct->timeout);

	DEBUGP("conntrack with id %d inserterd\n", ct->id);
	return 0;
}

/*
 * IPCTNL_MSG_CT_NEW handler: create a conntrack or update an existing one.
 *
 * All present attributes are size-checked against cta_min[].  The entry is
 * looked up by CTA_ORIG first, then by CTA_RPLY, with the conntrack write
 * lock held for the whole lookup-and-modify sequence.
 *
 *  - not found: created if NLM_F_CREATE is set, otherwise -ENOENT;
 *  - found:     -EEXIST if NLM_F_EXCL is set, otherwise updated.
 *
 * Returns 0 on success or a negative error.
 */
static int 
ctnetlink_new_conntrack(struct sock *ctnl, struct sk_buff *skb, 
			struct nlmsghdr *nlh, int *errp)
{
	struct ip_conntrack_tuple *otuple = NULL, *rtuple = NULL;
	struct ip_conntrack_tuple_hash *h = NULL;
	struct nfattr *cda[CTA_MAX];
	int i, err;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	for (i = 0; i < CTA_MAX; i++) {
		if (!cda[i])
			continue;
		if (NFA_PAYLOAD(cda[i]) < cta_min[i]) {
			DEBUGP("attribute %u has incorrect size %u<%u\n", 
				i, NFA_PAYLOAD(cda[i]), cta_min[i]);
			return -EINVAL;
		}
	}

	DEBUGP("all attribute sizes ok\n");

	if (cda[CTA_ORIG-1])
		otuple = NFA_DATA(cda[CTA_ORIG-1]);
	if (cda[CTA_RPLY-1])
		rtuple = NFA_DATA(cda[CTA_RPLY-1]);

	if (!otuple && !rtuple) {
		DEBUGP("no tuple found in request\n");
		return -EINVAL;
	}

	WRITE_LOCK(&ip_conntrack_lock);

	if (otuple)
		h = __ip_conntrack_find(otuple, NULL);
	if (!h && rtuple)
		h = __ip_conntrack_find(rtuple, NULL);

	if (!h) {
		DEBUGP("no such conntrack, create new\n");
		if (nlh->nlmsg_flags & NLM_F_CREATE)
			err = ctnetlink_create_conntrack(cda);
		else
			err = -ENOENT;
	} else {
		/* We manipulate the conntrack inside the global conntrack
		 * table lock, so there's no need to increase the refcount */
		DEBUGP("conntrack found\n");
		if (nlh->nlmsg_flags & NLM_F_EXCL)
			err = -EEXIST;
		else
			err = ctnetlink_change_conntrack(tuplehash_to_ctrack(h),
							 cda);
	}

	WRITE_UNLOCK(&ip_conntrack_lock);
	return err;
}

/*********************************************************************** 
 * EXPECT 
 ***********************************************************************/ 

/*
 * Append the tuple and mask of expectation @exp to @skb as CTA_EXP_TUPLE
 * and CTA_EXP_MASK attributes.
 *
 * Returns 0 on success, or -1 when NFA_PUT jumps to nfattr_failure.
 */
static inline int
ctnetlink_exp_dump_tuples(struct sk_buff *skb,
                          const struct ip_conntrack_expect *exp)
{
	int ret = -1;

	NFA_PUT(skb, CTA_EXP_TUPLE, sizeof(struct ip_conntrack_tuple),
		&exp->tuple);
	NFA_PUT(skb, CTA_EXP_MASK, sizeof(struct ip_conntrack_tuple),
		&exp->mask);
	ret = 0;

nfattr_failure:
	return ret;
}

static inline int
ctnetlink_exp_dump_timeout(struct sk_buff *skb, 
			   const struct ip_conntrack_expect *exp)
{
	NFA_PUT(skb, CTA_EXP_TIMEOUT, sizeof(unsigned long), &exp->timeout);
	return 0;
	
nfattr_failure:
	return -1;
}

/*
 * Placeholder for dumping per-protocol expectation data: currently emits
 * nothing and always returns 0.
 */
static inline int
ctnetlink_exp_dump_proto(struct sk_buff *skb,
                         const struct ip_conntrack_expect *exp)
{
	return 0;
}

/*
 * Build one expectation netlink message for @exp at the tail of @skb.
 *
 * @pid, @seq: values placed in the netlink header
 * @event:     IPCTNL_MSG_* type; the ctnetlink subsystem id is OR-ed in
 * @nowait:    when non-zero together with a non-zero @pid, the message is
 *             flagged NLM_F_MULTI
 *
 * Returns skb->len on success.  On failure the partially written message is
 * trimmed off again and -1 is returned.
 */
static int
ctnetlink_exp_fill_info(struct sk_buff *skb, u32 pid, u32 seq,
		    int event, 
		    int nowait, 
		    const struct ip_conntrack_expect *exp)
{
	unsigned char *start = skb->tail;
	struct nfgenmsg *nfmsg;
	struct nlmsghdr *nlh;

	event |= NFNL_SUBSYS_CTNETLINK << 8;
	nlh    = NLMSG_PUT(skb, pid, seq, event, sizeof(struct nfgenmsg));
	nfmsg  = NLMSG_DATA(nlh);

	if (nowait && pid)
		nlh->nlmsg_flags = NLM_F_MULTI;
	else
		nlh->nlmsg_flags = 0;
	nfmsg->nfgen_family = AF_INET;

	if (ctnetlink_exp_dump_tuples(skb, exp) < 0)
		goto nfattr_failure;
	if (ctnetlink_exp_dump_timeout(skb, exp) < 0)
		goto nfattr_failure;

	nlh->nlmsg_len = skb->tail - start;
	return skb->len;

nlmsg_failure:
nfattr_failure:
	/* drop whatever was written for this message */
	skb_trim(skb, start - skb->data);
	return -1;
}

/*
 * Append the attributes describing @exp (tuple, mask, timeout) to @skb.
 * Returns 0 on success, -1 as soon as one of the dumpers fails.
 */
static int ctnetlink_exp_dump(struct sk_buff *skb,
			      struct ip_conntrack_expect *exp)
{
	if (ctnetlink_exp_dump_tuples(skb, exp) < 0)
		return -1;

	if (ctnetlink_exp_dump_timeout(skb, exp) < 0)
		return -1;

	return 0;
}

/*
 * Pick the multicast group for events about @exp, based on the protocol
 * number of the expected tuple: TCP, UDP and ICMP have their own groups,
 * everything else goes to NFGRP_IPV4_CT_OTHER.
 */
static inline unsigned int
ctnetlink_exp_get_mcgroups(struct ip_conntrack_expect *exp)
{
	const u16 proto = exp->tuple.dst.protonum;

	if (proto == IPPROTO_TCP)
		return NFGRP_IPV4_CT_TCP;
	if (proto == IPPROTO_UDP)
		return NFGRP_IPV4_CT_UDP;
	if (proto == IPPROTO_ICMP)
		return NFGRP_IPV4_CT_ICMP;

	return NFGRP_IPV4_CT_OTHER;
}

/*
 * Expectation event notifier: for IPEXP_NEW events, build an
 * IPCTNL_MSG_EXP_NEW message (NLM_F_CREATE|NLM_F_EXCL) describing the
 * expectation and hand it to nfnetlink_send() for the multicast group
 * matching the expected protocol.  All other events are ignored.
 *
 * @this:   notifier block (unused)
 * @events: IPEXP_* event bits
 * @ptr:    the struct ip_conntrack_expect the event refers to
 *
 * The event check is done before the skb is allocated, so ignored events
 * do not cost an atomic allocation and free.
 *
 * NOTE(review): the message type here is tagged with
 * NFNL_SUBSYS_CTNETLINK_EXP, while ctnetlink_exp_fill_info() tags the same
 * message type with NFNL_SUBSYS_CTNETLINK - confirm which one listeners
 * expect.
 *
 * Always returns NOTIFY_DONE; failures are silently dropped.
 */
static int ctnetlink_exp_event(struct notifier_block *this,
                               unsigned long events, void *ptr)
{
	struct nlmsghdr *nlh;
	struct nfgenmsg *nfmsg;
	struct ip_conntrack_expect *exp = (struct ip_conntrack_expect *)ptr;
	struct sk_buff *skb;
	unsigned int type;
	unsigned char *b;
	int flags;

	/* only new expectations are reported */
	if (!(events & IPEXP_NEW))
		return NOTIFY_DONE;

	type  = IPCTNL_MSG_EXP_NEW;
	flags = NLM_F_CREATE|NLM_F_EXCL;

	/* fixed 400-byte buffer instead of NLMSG_GOODSIZE to save socket
	 * buffer space */
	skb = alloc_skb(400 /* NLMSG_GOODSIZE */, GFP_ATOMIC);
	if (!skb)
		return NOTIFY_DONE;

	b = skb->tail;

	type |= NFNL_SUBSYS_CTNETLINK_EXP << 8;
	nlh   = NLMSG_PUT(skb, 0, 0, type, sizeof(struct nfgenmsg));
	nfmsg = NLMSG_DATA(nlh);

	nlh->nlmsg_flags    = flags;
	nfmsg->nfgen_family = AF_INET;

	if (ctnetlink_exp_dump(skb, exp) < 0)
		goto nfattr_failure;

	nlh->nlmsg_len = skb->tail - b;
	nfnetlink_send(skb, 0, ctnetlink_exp_get_mcgroups(exp), 0);
	return NOTIFY_DONE;

nlmsg_failure:
nfattr_failure:
	kfree_skb(skb);
	return NOTIFY_DONE;
}

static int
ctnetlink_del_expect(struct sock *ctnl, struct sk_buff *skb, 
		     struct nlmsghdr *nlh, int *errp)
{
	struct ip_conntrack_expect *exp;
	struct ip_conntrack_tuple *tuple;
	struct nfattr *cda[CTA_MAX];

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	if (cda[CTA_ORIG-1] &&
	    NFA_PAYLOAD(cda[CTA_ORIG-1]) < cta_min[CTA_ORIG-1])
		return -EINVAL;

	if (cda[CTA_RPLY-1] &&
	    NFA_PAYLOAD(cda[CTA_RPLY-1]) < cta_min[CTA_RPLY-1])
		return -EINVAL;

	if (cda[CTA_ORIG-1])
		tuple = NFA_DATA(cda[CTA_ORIG-1]);
	else {
		if (cda[CTA_RPLY-1])
			tuple = NFA_DATA(cda[CTA_RPLY-1]);
		else
			return -EINVAL;
	}

	/* bump usage count to 2 */
	exp = ip_conntrack_expect_find_get(tuple);
	if (!exp)
		return -ENOENT;

	/* after list removal, usage count == 1 */
	ip_conntrack_unexpect_related(exp);
	/* we have put what we 'get' above. after this line usage count == 0 */
	ip_conntrack_expect_free(exp);

	return 0;
}

/*
 * Per-expectation callback used by ctnetlink_exp_dump_table() through
 * LIST_FIND: append one IPCTNL_MSG_EXP_NEW message for @exp to the dump
 * buffer @skb.
 *
 * The dump buffer belongs to the caller, which still reads skb->len after
 * the list walk, so it must not be freed here.  On failure
 * ctnetlink_exp_fill_info() has already trimmed the partial message off
 * again; this function merely reports the failure.
 *
 * Returns 0 on success, -1 if the message could not be added.
 * NOTE(review): a non-zero return presumably makes LIST_FIND stop the walk -
 * confirm against listhelp.h.
 */
static int
ctnetlink_exp_dump_build_msg(const struct ip_conntrack_expect *exp,
			 struct sk_buff *skb, u32 pid, u32 seq)
{
	int err;

	err = ctnetlink_exp_fill_info(skb, pid, seq, IPCTNL_MSG_EXP_NEW, 1, 
				      exp);
	if (err <= 0)
		return -1;

	return 0;
}

static int
ctnetlink_exp_dump_table(struct sk_buff *skb, struct netlink_callback *cb)
{
	DEBUGP("entered %s\n", __FUNCTION__);
	if (cb->args[0] == 0) {
		READ_LOCK(&ip_conntrack_lock);
		LIST_FIND(&ip_conntrack_expect_list, 
			  ctnetlink_exp_dump_build_msg,
			  struct ip_conntrack_expect *, skb,
			  NETLINK_CB(cb->skb).pid, cb->nlh->nlmsg_seq);
		READ_UNLOCK(&ip_conntrack_lock);
		cb->args[0] = 1;
	}
	DEBUGP("returning\n");

	return skb->len;
}


/*
 * IPCTNL_MSG_EXP_GET handler.
 *
 * With NLM_F_DUMP set, an expectation dump is started via
 * netlink_dump_start() (its result is stored in *errp) and the request is
 * pulled off skb.
 *
 * Otherwise the expectation matching the tuple in CTA_ORIG (preferred) or
 * CTA_RPLY is looked up and sent back to the requester as a single
 * IPCTNL_MSG_EXP_NEW message.
 *
 * Returns 0 on success; -EAFNOSUPPORT, -EINVAL, -ENOENT, -ENOMEM or the
 * netlink_unicast() error on failure; -1 if the reply could not be built.
 *
 * NOTE(review): the expectation returned by ip_conntrack_expect_find_get()
 * is never released on any path of this function; ctnetlink_del_expect()
 * treats that call as taking a reference - confirm and add the matching put.
 */
static int
ctnetlink_get_expect(struct sock *ctnl, struct sk_buff *skb, 
		     struct nlmsghdr *nlh, int *errp)
{
	struct ip_conntrack_expect *exp;
	struct ip_conntrack_tuple *tuple;
	struct nfattr *cda[CTA_MAX];
	struct sk_buff *skb2 = NULL;
	int err, proto;

	DEBUGP("entered %s\n", __FUNCTION__);
	if (nlh->nlmsg_flags & NLM_F_DUMP) {
		struct nfgenmsg *msg = NLMSG_DATA(nlh);
		u32 rlen;

		if (msg->nfgen_family != AF_INET)
			return -EAFNOSUPPORT;

		DEBUGP("starting dump\n");
			if ((*errp = netlink_dump_start(ctnl, skb, nlh,
		    				ctnetlink_exp_dump_table,
						ctnetlink_done)) != 0)
			return -EINVAL;
		/* consume the request, but never more than the skb holds */
		rlen = NLMSG_ALIGN(nlh->nlmsg_len);
		if (rlen > skb->len)
			rlen = skb->len;
		skb_pull(skb, rlen);
		return 0;
	}

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	/* both tuples are size-checked even though only one is used */
	if (cda[CTA_ORIG-1]
	    && NFA_PAYLOAD(cda[CTA_ORIG-1]) < sizeof(struct ip_conntrack_tuple))
		return -EINVAL;

	if (cda[CTA_RPLY-1]
	    && NFA_PAYLOAD(cda[CTA_RPLY-1]) < sizeof(struct ip_conntrack_tuple))
		return -EINVAL;

	/* the original tuple wins when both are supplied */
	if (cda[CTA_ORIG-1])
		tuple = NFA_DATA(cda[CTA_ORIG-1]);
	else {
		if (cda[CTA_RPLY-1])
			tuple = NFA_DATA(cda[CTA_RPLY-1]);
		else
			return -EINVAL;
	}

	exp = ip_conntrack_expect_find_get(tuple);
	if (!exp)
		return -ENOENT;

	skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC);
	if (!skb2)
		return -ENOMEM;
	NETLINK_CB(skb2).dst_pid = NETLINK_CB(skb).pid;
	/* proto is assigned here but not used afterwards */
	proto = exp->tuple.dst.protonum;
	
	err = ctnetlink_exp_fill_info(skb2, NETLINK_CB(skb).pid, 
				      nlh->nlmsg_seq, IPCTNL_MSG_EXP_NEW,
				      1, exp);
	if (err <= 0)
		goto nlmsg_failure;

	err = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
	if (err < 0)
		return err;
	return 0;

nlmsg_failure:
	if (skb2)
		kfree_skb(skb2);
	return -1;
}

/*
 * Modifying an existing expectation is not implemented: always returns
 * -EOPNOTSUPP without looking at either argument.
 */
static int
ctnetlink_change_expect(struct ip_conntrack_expect *x, struct nfattr *cda[])
{

	return -EOPNOTSUPP;
}

static int
ctnetlink_create_expect(struct nfattr *cda[])
{
	struct ip_conntrack_tuple *tuple, *mask;
	struct ip_conntrack_tuple *orig, *reply;
	struct ip_conntrack_tuple_hash *h;
	struct ip_conntrack_expect *exp;
	struct ip_conntrack_helper *helper;
	unsigned long timeout;
	int err;

	DEBUGP("entered %s\n", __FUNCTION__);

	if (!(cda[CTA_ORIG-1] || cda[CTA_RPLY-1])) {
		DEBUGP("required attributes missing\n");
		return -EINVAL;
	}

	tuple = NFA_DATA(cda[CTA_EXP_TUPLE-1]);
	mask  = NFA_DATA(cda[CTA_EXP_MASK-1]);
	orig  = NFA_DATA(cda[CTA_ORIG-1]);
	reply = NFA_DATA(cda[CTA_RPLY-1]);

	/* Look for master conntrack of this expectation */
	h = ip_conntrack_find_get(reply, NULL);
	if (h == NULL)
		return -ENOENT;

	helper = tuplehash_to_ctrack(h)->helper;

	if (cda[CTA_EXP_TIMEOUT-1])
		timeout = *(unsigned long *)NFA_DATA(cda[CTA_EXP_TIMEOUT-1]);
	else if (helper && helper->timeout)
		timeout = helper->timeout;
	else
		return -EINVAL;

	exp = ip_conntrack_expect_alloc();
	if (!exp)
		return -ENOMEM;
	
	exp->expectfn = NULL;
	exp->master = tuplehash_to_ctrack(h);
	memcpy(&exp->tuple, tuple, sizeof(struct ip_conntrack_tuple));
	memcpy(&exp->mask, mask, sizeof(struct ip_conntrack_tuple));

	exp->timeout.expires = jiffies + timeout * HZ;
	add_timer(&exp->timeout);

	err = ip_conntrack_expect_related(exp);
	if (err < 0)
		return err;
	
	return 0;
}

/*
 * IPCTNL_MSG_EXP_NEW handler: create an expectation or update an existing
 * one.
 *
 * All present attributes are size-checked against cta_min[]; CTA_EXP_TUPLE
 * and CTA_EXP_MASK are mandatory.  The expectation is looked up by its
 * tuple with the conntrack write lock held.
 *
 *  - not found: created if NLM_F_CREATE is set, otherwise -ENOENT;
 *  - found:     -EEXIST if NLM_F_EXCL is set, otherwise passed to
 *               ctnetlink_change_expect().
 *
 * Returns 0 on success or a negative error.
 */
static int
ctnetlink_new_expect(struct sock *ctnl, struct sk_buff *skb,
		     struct nlmsghdr *nlh, int *errp)
{
	struct nfattr *cda[CTA_MAX];
	struct ip_conntrack_tuple *tuple;
	struct ip_conntrack_expect *exp;
	int i, err;

	if (nfnetlink_check_attributes(ctnl_subsys, nlh, cda) < 0)
		return -EINVAL;

	for (i = 0; i < CTA_MAX; i++) {
		if (!cda[i])
			continue;
		if (NFA_PAYLOAD(cda[i]) < cta_min[i])
			return -EINVAL;
	}

	if (!(cda[CTA_EXP_TUPLE-1] && cda[CTA_EXP_MASK-1]))
		return -EINVAL;

	tuple = NFA_DATA(cda[CTA_EXP_TUPLE-1]);

	WRITE_LOCK(&ip_conntrack_lock);
	exp = __ip_conntrack_expect_find(tuple);

	if (exp == NULL) {
		if (nlh->nlmsg_flags & NLM_F_CREATE)
			err = ctnetlink_create_expect(cda);
		else
			err = -ENOENT;
	} else {
		if (nlh->nlmsg_flags & NLM_F_EXCL)
			err = -EEXIST;
		else
			err = ctnetlink_change_expect(exp, cda);
	}

	WRITE_UNLOCK(&ip_conntrack_lock);
	return err;
}

/* Notifier blocks through which conntrack and expectation events reach us */

/* Notifier block for conntrack events; registered in ctnetlink_init() via
 * ip_conntrack_register_notifier().  The first member is the callback. */
static struct notifier_block ctnl_notifier = {
	ctnetlink_conntrack_event,
	NULL,
	0
};

/* Notifier block for expectation events; registered in ctnetlink_init() via
 * ip_conntrack_expect_register_notifier().  The first member is the
 * callback. */
static struct notifier_block ctnl_exp_noti = {
	ctnetlink_exp_event,
	NULL,
	0
};

/*
 * Module exit: undo ctnetlink_init() in reverse order - unregister both
 * event notifiers, unregister the nfnetlink subsystem and free its
 * descriptor.
 */
static void __exit ctnetlink_exit(void)
{
	printk("ctnetlink: unregistering with nfnetlink.\n");

	/* stop receiving events before the subsystem goes away */
	ip_conntrack_expect_unregister_notifier(&ctnl_exp_noti);
	ip_conntrack_unregister_notifier(&ctnl_notifier);

	nfnetlink_subsys_unregister(ctnl_subsys);
	kfree(ctnl_subsys);
}

/*
 * Module init: allocate and populate the nfnetlink subsystem descriptor,
 * register it, then register the conntrack and expectation event notifiers.
 *
 * Every message type requires CAP_NET_ADMIN.  On failure, everything set up
 * so far is torn down in reverse order.
 *
 * Returns 0 on success or the negative error code of the failing step.
 */
static int __init ctnetlink_init(void)
{
	int ret;

	ctnl_subsys = nfnetlink_subsys_alloc(IPCTNL_MSG_COUNT);
	if (!ctnl_subsys) {
		ret = -ENOMEM;
		goto err_out; 
	}

	ctnl_subsys->name = "conntrack";
	ctnl_subsys->subsys_id = NFNL_SUBSYS_CTNETLINK;
	ctnl_subsys->cb_count = IPCTNL_MSG_COUNT;
	ctnl_subsys->attr_count = CTA_MAX;

	/* conntrack message handlers */
	ctnl_subsys->cb[IPCTNL_MSG_CT_NEW].call = ctnetlink_new_conntrack;
	ctnl_subsys->cb[IPCTNL_MSG_CT_NEW].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[IPCTNL_MSG_CT_DELETE].call = ctnetlink_del_conntrack;
	ctnl_subsys->cb[IPCTNL_MSG_CT_DELETE].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[IPCTNL_MSG_CT_GET].call = ctnetlink_get_conntrack;
	ctnl_subsys->cb[IPCTNL_MSG_CT_GET].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[IPCTNL_MSG_CT_GET_CTRZERO].call =
						ctnetlink_get_conntrack;
	ctnl_subsys->cb[IPCTNL_MSG_CT_GET_CTRZERO].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[IPCTNL_MSG_CT_FLUSH].call = ctnetlink_flush_conntrack;
	ctnl_subsys->cb[IPCTNL_MSG_CT_FLUSH].cap_required = CAP_NET_ADMIN;

	/* expectation message handlers */
	ctnl_subsys->cb[IPCTNL_MSG_EXP_NEW].call = ctnetlink_new_expect;
	ctnl_subsys->cb[IPCTNL_MSG_EXP_NEW].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[IPCTNL_MSG_EXP_DELETE].call = ctnetlink_del_expect;
	ctnl_subsys->cb[IPCTNL_MSG_EXP_DELETE].cap_required = CAP_NET_ADMIN;
	ctnl_subsys->cb[IPCTNL_MSG_EXP_GET].call = ctnetlink_get_expect;
	ctnl_subsys->cb[IPCTNL_MSG_EXP_GET].cap_required = CAP_NET_ADMIN;

	/* dump/event mask configuration */
	ctnl_subsys->cb[IPCTNL_MSG_CONFIG].call = ctnetlink_config;
	ctnl_subsys->cb[IPCTNL_MSG_CONFIG].cap_required = CAP_NET_ADMIN;

	printk("ctnetlink v%s: registering with nfnetlink.\n", ctversion);

	/* Assign first, compare afterwards: "ret = f() < 0" would store the
	 * result of the comparison and make init return 1 on failure. */
	ret = nfnetlink_subsys_register(ctnl_subsys);
	if (ret < 0) {
		printk("ctnetlink_init: cannot register with nfnetlink.\n");
		goto err_free_subsys;
	}

	ret = ip_conntrack_register_notifier(&ctnl_notifier);
	if (ret < 0) {
		printk("ctnetlink_init: cannot register notifier.\n");
		goto err_unreg_subsys;
	}

	ret = ip_conntrack_expect_register_notifier(&ctnl_exp_noti);
	if (ret < 0) {
		printk("ctnetlink_init: cannot register expect notifier.\n");
		goto err_unreg_notify;
	}

	return 0;
	
err_unreg_notify:
	ip_conntrack_unregister_notifier(&ctnl_notifier);
err_unreg_subsys:
	nfnetlink_subsys_unregister(ctnl_subsys);
err_free_subsys:
	kfree(ctnl_subsys);
err_out:
	return ret;
}

module_init(ctnetlink_init);
module_exit(ctnetlink_exit);
