#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/udp.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nfnetlink_conntrack.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_labels.h>
#include <net/netfilter/nf_conntrack_tuple.h>
#include <net/netfilter/nf_conntrack_l3proto.h>
#include <net/netfilter/nf_conntrack_l4proto.h>
#include "fastpath_core.h"
#include <net/rtl/rtl_types.h>
#include <net/rtl/rtl865x_nat.h>

/*
 * Bit index inside nlsvr_t.flags.  fpnl_sk_data_ready() sets it when it
 * schedules the work item, and fpnl_work_fn() clears it before it starts
 * reading from the socket.
 */
#define DATA_AVAIL  0

/* State of the in-kernel netlink listener used by this file. */
struct nlsvr_t {	
	struct socket *sock;		/* socket created in fastpath_event_init() */
	struct sockaddr_nl addr;	/* address (family + group mask) passed to kernel_bind() */
	struct work_struct work;	/* deferred receive handler, runs fpnl_work_fn() */
	u8   recvbuf[1024];		/* receive buffer; fpnl_work_fn() reads a struct nlmsghdr from its start */
	unsigned long flags;		/* bit field, see DATA_AVAIL */
};

/* The single listener instance. */
static struct nlsvr_t nlsvr;

/*
 * Validation policy for the top-level CTA_* attributes of a conntrack
 * message.  Passed to nla_parse() in fpnl_work_fn().
 */
static const struct nla_policy ct_nla_policy[CTA_MAX+1] = {
	[CTA_TUPLE_ORIG]	= { .type = NLA_NESTED },
	[CTA_TUPLE_REPLY]	= { .type = NLA_NESTED },
	[CTA_STATUS] 		= { .type = NLA_U32 },
	[CTA_PROTOINFO]		= { .type = NLA_NESTED },
	[CTA_HELP]		= { .type = NLA_NESTED },
	[CTA_NAT_SRC]		= { .type = NLA_NESTED },
	[CTA_TIMEOUT] 		= { .type = NLA_U32 },
	[CTA_MARK]		= { .type = NLA_U32 },
	[CTA_ID]		= { .type = NLA_U32 },
	[CTA_NAT_DST]		= { .type = NLA_NESTED },
	[CTA_TUPLE_MASTER]	= { .type = NLA_NESTED },
	[CTA_NAT_SEQ_ADJ_ORIG]  = { .type = NLA_NESTED },
	[CTA_NAT_SEQ_ADJ_REPLY] = { .type = NLA_NESTED },
	[CTA_ZONE]		= { .type = NLA_U16 },
	[CTA_MARK_MASK]		= { .type = NLA_U32 },
	[CTA_LABELS]		= { .type = NLA_BINARY,
				    .len = NF_CT_LABELS_MAX_SIZE },
	[CTA_LABELS_MASK]	= { .type = NLA_BINARY,
				    .len = NF_CT_LABELS_MAX_SIZE },
};

/*
 * Validation policy for the attributes nested inside a tuple attribute.
 * Passed to nla_parse_nested() in netlink_parse_tuple().
 */
static const struct nla_policy tuple_nla_policy[CTA_TUPLE_MAX+1] = {
	[CTA_TUPLE_IP]		= { .type = NLA_NESTED },
	[CTA_TUPLE_PROTO]	= { .type = NLA_NESTED },
};

/*
 * Validation policy for the attributes nested inside CTA_TUPLE_PROTO.
 * Passed to nla_parse_nested() in netlink_parse_tuple_proto(); only the
 * protocol number is constrained here.
 */
static const struct nla_policy proto_nla_policy[CTA_PROTO_MAX+1] = {
	[CTA_PROTO_NUM]	= { .type = NLA_U8 },
};


/*
 * Extract the conntrack zone from an optional zone attribute.
 *
 * attr - zone attribute, or NULL when the message carries none
 * zone - receives the zone in host byte order (0 when attr is NULL)
 *
 * Returns 0 on success.  When the attribute is present but the kernel is
 * built without CONFIG_NF_CONNTRACK_ZONES, returns -EOPNOTSUPP and leaves
 * *zone untouched.
 */
int netlink_parse_zone(const struct nlattr *attr, unsigned short *zone)
{
	if (!attr) {
		*zone = 0;
		return 0;
	}

#ifdef CONFIG_NF_CONNTRACK_ZONES
	*zone = ntohs(nla_get_be16(attr));
	return 0;
#else
	return -EOPNOTSUPP;
#endif
}


/*
 * Fill the layer-3 part of a tuple from a nested CTA_TUPLE_IP attribute.
 *
 * attr  - the nested IP attribute
 * tuple - tuple to fill; tuple->src.l3num must already be set because it
 *         selects the l3 protocol handler
 *
 * Returns a negative value if parsing or validation fails, otherwise the
 * result of the handler's nlattr_to_tuple() callback.  If the handler has
 * no such callback, the result of the initial parse is returned.
 */
int netlink_parse_tuple_ip(struct nlattr *attr, struct nf_conntrack_tuple *tuple)
{
	struct nlattr *ip_attrs[CTA_IP_MAX+1];
	struct nf_conntrack_l3proto *proto;
	int rc;

	rc = nla_parse_nested(ip_attrs, CTA_IP_MAX, attr, NULL);
	if (rc < 0)
		return rc;

	rcu_read_lock();

	proto = __nf_ct_l3proto_find(tuple->src.l3num);
	if (unlikely(!proto->nlattr_to_tuple))
		goto unlock;

	/* Re-check the nested attributes against the handler's own policy. */
	rc = nla_validate_nested(attr, CTA_IP_MAX, proto->nla_policy);
	if (rc == 0)
		rc = proto->nlattr_to_tuple(ip_attrs, tuple);

unlock:
	rcu_read_unlock();
	return rc;
}

int netlink_parse_tuple_proto(struct nlattr *attr,
			    struct nf_conntrack_tuple *tuple)
{
	struct nlattr *tb[CTA_PROTO_MAX+1];
	struct nf_conntrack_l4proto *l4proto;
	int ret = 0;

	ret = nla_parse_nested(tb, CTA_PROTO_MAX, attr, proto_nla_policy);
	if (ret < 0)
		return ret;

	if (!tb[CTA_PROTO_NUM])
		return -EINVAL;
	tuple->dst.protonum = nla_get_u8(tb[CTA_PROTO_NUM]);

	rcu_read_lock();
	l4proto = __nf_ct_l4proto_find(tuple->src.l3num, tuple->dst.protonum);

	if (likely(l4proto->nlattr_to_tuple)) {
		ret = nla_validate_nested(attr, CTA_PROTO_MAX,
					  l4proto->nla_policy);
		if (ret == 0)
			ret = l4proto->nlattr_to_tuple(tb, tuple);
	}
	rcu_read_unlock();

	return ret;
}

/*
 * Build a conntrack tuple from one of the tuple attributes of a message.
 *
 * cda   - top-level attribute table produced by nla_parse()
 * tuple - output; zeroed first, then filled from the nested attributes
 * type  - which entry of cda to read (e.g. CTA_TUPLE_ORIG, CTA_TUPLE_REPLY)
 * l3num - layer-3 family stored in tuple->src.l3num
 *
 * Returns 0 on success, -EINVAL if the requested attribute or one of its
 * mandatory IP/PROTO children is absent, or the negative error reported by
 * the nested parsers.
 */
int netlink_parse_tuple(struct nlattr * const cda[],
		      struct nf_conntrack_tuple *tuple,
		      enum ctattr_type type, u_int8_t l3num)
{
	struct nlattr *tb[CTA_TUPLE_MAX+1];
	int err;

	memset(tuple, 0, sizeof(*tuple));

	/*
	 * The attribute is optional at the netlink level; nla_parse_nested()
	 * dereferences it unconditionally, so reject a missing one here.
	 */
	if (!cda[type])
		return -EINVAL;

	err = nla_parse_nested(tb, CTA_TUPLE_MAX, cda[type], tuple_nla_policy);
	if (err < 0)
		return err;

	if (!tb[CTA_TUPLE_IP])
		return -EINVAL;

	/* Must be set before the IP/proto parsers, which look handlers up by it. */
	tuple->src.l3num = l3num;

	err = netlink_parse_tuple_ip(tb[CTA_TUPLE_IP], tuple);
	if (err < 0)
		return err;

	if (!tb[CTA_TUPLE_PROTO])
		return -EINVAL;

	err = netlink_parse_tuple_proto(tb[CTA_TUPLE_PROTO], tuple);
	if (err < 0)
		return err;

	/* orig and expect tuples get DIR_ORIGINAL */
	if (type == CTA_TUPLE_REPLY)
		tuple->dst.dir = IP_CT_DIR_REPLY;
	else
		tuple->dst.dir = IP_CT_DIR_ORIGINAL;

	return 0;
}

extern int rtl8676_del_L34Unicast_hwacc_ct(struct nf_conn *ct);
/*
 * React to a conntrack event by updating the fastpath table.
 *
 * type        - conntrack message type; only IPCTNL_MSG_CT_DELETE is acted on
 * u3          - address family of the connection; only PF_INET is acted on
 * orig_tuple  - original-direction tuple of the connection
 * reply_tuple - reply-direction tuple of the connection
 *
 * For a matching non-ICMP connection the pair of tuples is handed to
 * fastpath_delRoutedNaptConnection().  Always returns SUCCESS.
 */
int fastpath_conntrack_event(unsigned char type, unsigned char u3, struct nf_conntrack_tuple orig_tuple, struct nf_conntrack_tuple reply_tuple)
{
	if (type != IPCTNL_MSG_CT_DELETE)
		return SUCCESS;

	if (u3 != PF_INET)
		return SUCCESS;

	if (orig_tuple.dst.protonum == IPPROTO_ICMP)
		return SUCCESS;

	fastpath_delRoutedNaptConnection(orig_tuple, reply_tuple);

	return SUCCESS;
}

static int ksocket_receive(struct socket* sock, unsigned char* buf, int len)
{	
	struct sockaddr_nl nladdr;
	struct msghdr msg;
	struct iovec iov;
	mm_segment_t oldfs;
	int size = 0;

	if (sock->sk==NULL) return 0;

	iov.iov_base = buf;
	iov.iov_len = len;

	msg.msg_flags = MSG_DONTWAIT;
	msg.msg_name = &nladdr;
	msg.msg_namelen  = sizeof(nladdr);
	msg.msg_control = NULL;
	msg.msg_controllen = 0;
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_control = NULL;

	memset(&nladdr, 0, sizeof(nladdr));
	nladdr.nl_family = AF_NETLINK;

	oldfs = get_fs();
	set_fs(KERNEL_DS);
	size = sock_recvmsg(sock,&msg,len,msg.msg_flags);
	set_fs(oldfs);

	return size;
}


static void fpnl_work_fn(struct work_struct *work) {
	struct nlsvr_t *f = container_of(work, struct nlsvr_t, work);
	struct nf_conntrack_tuple orig_tuple, reply_tuple;
	struct nlattr *cda[CTA_MAX + 1];
	struct nlattr *attr;
	struct nfgenmsg *nfmsg;
	unsigned char u3, type;
	unsigned short zone;
	int min_len, attrlen, ret, len,err;
	struct nlmsghdr *h = (struct nlmsghdr *)f->recvbuf;

	clear_bit(DATA_AVAIL, &f->flags);
	do {
		ret = ksocket_receive(f->sock, f->recvbuf, sizeof(f->recvbuf)); /* might need to repeat, so multiple events are handled  */
		
		if (ret <= 0)
			return;

		len = h->nlmsg_len;
		
		if (!NLMSG_OK(h, ret)) {
			printk("malformed message: len=%d\n", len);
			return;
		}

		type = (h->nlmsg_type & (IPCTNL_MSG_MAX-1));
//		printk("fpnl_work: type: %d\n", type);

		min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
		nfmsg = nlmsg_data(h);
		attr = (void *)h + min_len;
		u3 = nfmsg->nfgen_family;
		attrlen = h->nlmsg_len - min_len;;

		rcu_read_lock();
		err = nla_parse(cda, CTA_MAX, attr, attrlen, ct_nla_policy);
		if (err < 0){
			rcu_read_unlock();
			return;
		}

		err = netlink_parse_zone(cda[CTA_ZONE], &zone);
		if (err < 0){
			rcu_read_unlock();
			return;
		}

		err = netlink_parse_tuple(cda, &orig_tuple, CTA_TUPLE_ORIG, u3);
		if (err < 0){
			rcu_read_unlock();
			return;
		}
 
		err = netlink_parse_tuple(cda, &reply_tuple, CTA_TUPLE_REPLY, u3);
		if (err < 0){
			rcu_read_unlock();
			return;
		}
	
		rcu_read_unlock();
		
		fastpath_conntrack_event(type, u3, orig_tuple, reply_tuple);
			
	} while (ret > 0);
}


/*
 * Socket data-ready hook installed by setup_sock().
 *
 * Schedules the listener's work item unless DATA_AVAIL was already set,
 * i.e. unless a previous notification is still waiting to be handled.
 */
static void fpnl_sk_data_ready(struct sock *sk)
{
	struct nlsvr_t *svr = (struct nlsvr_t *)sk->sk_user_data;

	if (test_and_set_bit(DATA_AVAIL, &svr->flags))
		return;

	schedule_work(&svr->work);
}

/*
 * Attach listener state to a socket and install the data-ready hook.
 *
 * sk   - socket whose sk_data_ready callback is replaced
 * data - stored in sk->sk_user_data; fpnl_sk_data_ready() reads it back as
 *        a struct nlsvr_t pointer
 */
static void setup_sock(struct sock *sk, void *data) {
	/* Store the context before installing the callback that dereferences it. */
	sk->sk_user_data = data;
	sk->sk_data_ready = fpnl_sk_data_ready;
}

int fastpath_event_init(void)
{
	int ret;
		
	ret = sock_create_kern(PF_NETLINK, SOCK_RAW, NETLINK_NETFILTER, &nlsvr.sock);
	
	if (ret < 0)
	{
		printk(KERN_INFO ": Could not create a datagram socket, error = %d\n", ret);		
	}	
	
	nlsvr.addr.nl_family = PF_NETLINK;      
	nlsvr.addr.nl_groups = (1<<(NFNLGRP_CONNTRACK_DESTROY-1));
	
	ret = kernel_bind(nlsvr.sock, (struct sockaddr *)&nlsvr.addr, sizeof(struct sockaddr));
	if ( ret < 0) 
	{
		printk(KERN_INFO ": Could not bind or connect to socket, error = %d\n", ret);		
	}
	
	setup_sock(nlsvr.sock->sk, (void *)&nlsvr);	
	
	INIT_WORK(&nlsvr.work, fpnl_work_fn);

	return ret;
}

void fastpath_event_cleanup(void)
{
}