/*
 * spc: Slow Protocols Channel
 * Ethernet OAM protocol (IEEE 802.3ah), carried over the IEEE 802.3
 * "Slow Protocols" EtherType (ETH_P_SLOW)
 */
#include <linux/gfp.h>
#include <linux/net.h>
#include <linux/poll.h>
#include <linux/wait.h>
#include <net/sock.h>
#include <bspchip.h>
#include <net/rtl/rtl_nic.h>

#define AF_SPC		28	/* Slow Protocols Channel	*/
#define PF_SPC		AF_SPC
#undef DEBUG

/* Device that newly created sockets are bound to (chosen in spc_init). */
static struct net_device *default_dev;
/* Readers blocked in spc_recvmsg()/spc_poll() wait here. */
static wait_queue_head_t spc_sleep;
/* Frames accepted by spc_rcv(), shared by every SPC socket. */
static struct sk_buff_head spc_receive_queue;
/* Destination MAC address used for every transmitted frame. */
static const __u8 mcast_da[] = { 0x01, 0x80, 0xc2, 0x00, 0x00, 0x02 };
/* First payload octet identifying SPC frames: prepended on send, stripped on receive. */
static const __u8 subtype = 0xaa;

static struct proto spc_proto = {
	.obj_size = sizeof(struct sock),
	.owner = THIS_MODULE,
	.name = "SPC",
};

/*
 * alloc_spc_sk - allocate and initialise a struct sock for an SPC socket
 * @net: network namespace the sock belongs to
 * @sock: BSD socket the new sock is attached to
 * @family: protocol family stored in the sock
 *
 * Return: the new sock (already linked to @sock by sock_init_data()),
 * or NULL if sk_alloc() failed.
 */
static struct sock *alloc_spc_sk(struct net *net, struct socket *sock, int family)
{
	struct sock *new_sk = sk_alloc(net, family, GFP_KERNEL, &spc_proto);

	if (new_sk != NULL)
		sock_init_data(sock, new_sk);

	return new_sk;
}

/*
 * create - common socket constructor for PF_SPC
 * @net: namespace for the new sock
 * @sock: socket being created (sock->ops already set by the caller)
 * @protocol: ignored
 * @family: protocol family to record in the sock
 *
 * Return: 0 on success, -EINVAL for SOCK_STREAM, -ENOMEM if allocation fails.
 *
 * NOTE(review): there is no capability check here, so any local user can
 * open this socket and send/receive frames; consider requiring
 * CAP_NET_RAW as AF_PACKET does.
 */
static int create(struct net *net, struct socket *sock, int protocol, int family)
{
	struct sock *sk;

	sock->sk = NULL;

	/* Stream sockets are rejected; all other socket types are accepted. */
	if (sock->type == SOCK_STREAM)
		return -EINVAL;

	sk = alloc_spc_sk(net, sock, family);
	if (sk == NULL)
		return -ENOMEM;

	/* Every new socket starts out bound to the module's default device. */
	lock_sock(sk);
	sk->sk_bound_dev_if = default_dev->ifindex;
	release_sock(sk);

	return 0;
}

static int spc_release(struct socket *sock)
{
	if (sock->sk)
		sk_free(sock->sk);
	return 0;
}

/*
 * spc_poll - report readability of the shared SPC receive queue
 * @file: file backing the socket
 * @sock: unused; all SPC sockets share one receive queue
 * @wait: poll table to register spc_sleep with
 *
 * Return: POLLIN | POLLRDNORM while at least one frame is queued, else 0.
 * Writability is never reported.
 */
static unsigned int spc_poll(struct file *file, struct socket *sock, poll_table *wait)
{
	int have_data;

	poll_wait(file, &spc_sleep, wait);
	have_data = (skb_peek(&spc_receive_queue) != NULL);

	return have_data ? (POLLIN | POLLRDNORM) : 0;
}

/*
 * spc_setsockopt - set an SPC socket option
 * @sock: target socket
 * @level: must be SOL_SOCKET
 * @optname: only SO_BINDTODEVICE is supported
 * @optval: user buffer holding the interface name
 * @optlen: length of @optval; at most IFNAMSIZ - 1 bytes are read
 *
 * Binds the socket to the named device in init_net.
 *
 * Return: 0 on success, -ENOPROTOOPT for an unsupported level/option,
 * -EINVAL for a negative @optlen, -EFAULT on a bad user pointer,
 * -ENODEV if no device has that name.
 */
static int spc_setsockopt(struct socket *sock, int level, int optname,
			  char __user *optval, int optlen)
{
	struct sock *sk = sock->sk;
	struct net_device *dev;
	char devname[IFNAMSIZ];
	int err;

	if (level != SOL_SOCKET)
		return -ENOPROTOOPT;
	if (optlen < 0)
		return -EINVAL;

	lock_sock(sk);

	if (optname != SO_BINDTODEVICE) {
		err = -ENOPROTOOPT;
		goto out;
	}

	/* Read at most IFNAMSIZ - 1 bytes so devname always stays NUL-terminated. */
	memset(devname, 0, sizeof(devname));
	if (copy_from_user(devname, optval, min_t(int, optlen, IFNAMSIZ - 1))) {
		err = -EFAULT;
		goto out;
	}

	dev = dev_get_by_name(&init_net, devname);
	if (!dev) {
		err = -ENODEV;
		goto out;
	}

	sk->sk_bound_dev_if = dev->ifindex;
	dev_put(dev);
	err = 0;

out:
	release_sock(sk);
	return err;
}

/*
 * spc_getsockopt - read an SPC socket option
 * @sock: target socket
 * @level: must be SOL_SOCKET
 * @optname: only SO_BINDTODEVICE is supported
 * @optval: user buffer receiving the NUL-terminated device name
 * @optlen: in: size of @optval; out: bytes written (including the NUL)
 *
 * Reports the name of the device the socket is bound to, or an empty
 * string if that ifindex no longer resolves to a device.
 *
 * Return: 0 on success, -ENOPROTOOPT for an unsupported level/option,
 * -EINVAL if *@optlen < 1, -EFAULT on a bad user pointer.
 */
static int spc_getsockopt(struct socket *sock, int level, int optname, char __user *optval, int __user *optlen)
{
	struct sock *sk = sock->sk;
	char devname[IFNAMSIZ];
	struct net_device *dev;
	void *valptr;
	int length;

	if (level != SOL_SOCKET)
		return -ENOPROTOOPT;

	if (get_user(length, optlen))
		return -EFAULT;

	/* Not even room for the terminating NUL: an invalid argument, not a fault. */
	if (length < 1)
		return -EINVAL;

	length = min_t(unsigned int, length, sizeof(devname));

	lock_sock(sk);
	switch (optname) {
	case SO_BINDTODEVICE:
		dev = dev_get_by_index(&init_net, sk->sk_bound_dev_if);
		if (dev) {
			strlcpy(devname, dev->name, length);
			length = strlen(devname) + 1;
			/* Only drop the reference we actually took. */
			dev_put(dev);
		} else {
			/* Bound device has gone away: report an empty name. */
			devname[0] = '\0';
			length = 1;
		}
		valptr = devname;
		break;

	default:
		release_sock(sk);
		return -ENOPROTOOPT;
	}
	release_sock(sk);

	if (put_user(length, optlen))
		return -EFAULT;

	return copy_to_user(optval, valptr, length) ? -EFAULT : 0;
}

static int spc_rcv(struct sk_buff *skb, struct net_device *dev,
		   struct packet_type *pt, struct net_device *orig_dev)
{
	int ret = 0;

	if ((*(__u8 *)skb->data) == subtype) {
		skb_pull(skb, 1);
		skb_queue_tail(&spc_receive_queue, skb);
		wake_up(&spc_sleep);
	} else {
		kfree_skb(skb);
		ret = NET_RX_DROP;
	}

	return ret;
}

/*
 * spc_recvmsg - dequeue one SPC frame into the user's iovec
 * @iocb: unused
 * @sock: unused; all SPC sockets share one receive queue
 * @m: destination message; msg_namelen is set to 0 (no source address)
 * @total_len: size of the user buffer
 * @flags: only MSG_DONTWAIT is accepted
 *
 * Blocks (interruptibly) until a frame is queued unless MSG_DONTWAIT is
 * given.  The subtype octet has already been stripped by spc_rcv().  If
 * the frame is larger than @total_len it is truncated, the rest is
 * discarded, and MSG_TRUNC is set in m->msg_flags (datagram semantics).
 *
 * Return: number of bytes copied, -EOPNOTSUPP for unsupported flags,
 * -EAGAIN if MSG_DONTWAIT and nothing is queued, -ERESTARTSYS if
 * interrupted by a signal, or the error from skb_copy_datagram_iovec().
 */
static int spc_recvmsg(struct kiocb *iocb, struct socket *sock,
		 struct msghdr *m, size_t total_len, int flags)
{
	DECLARE_WAITQUEUE(wait, current);
#ifdef DEBUG
	int k;
#endif
	int ret = 0;
	struct sk_buff *skb;

	if (flags & ~MSG_DONTWAIT)
		return -EOPNOTSUPP;

	/*
	 * Classic sleep loop: set the task state before re-testing the queue
	 * so a wake_up() from spc_rcv() between the test and schedule() is
	 * not lost.
	 */
	add_wait_queue(&spc_sleep, &wait);
	set_current_state(TASK_INTERRUPTIBLE);
	while (!(skb = skb_dequeue(&spc_receive_queue))) {
		if (flags & MSG_DONTWAIT) {
			ret = -EAGAIN;
			break;
		}
		schedule();
		set_current_state(TASK_INTERRUPTIBLE);
		if (signal_pending(current)) {
			ret = -ERESTARTSYS;
			break;
		}
	}
	set_current_state(TASK_RUNNING);
	remove_wait_queue(&spc_sleep, &wait);

	/* The loop only exits without an error once skb is non-NULL. */
	if (ret < 0)
		return ret;

	m->msg_namelen = 0;

#ifdef DEBUG
	printk("\nI: SPC");
	for (k = 0; k < skb->len; k++) {
		if (k % 16 == 0)
			printk("\n  ");
		printk(" %2.2x", skb->data[k]);
	}
	printk("\n");
#endif

	if (skb->len > total_len)
		m->msg_flags |= MSG_TRUNC;
	else
		total_len = skb->len;

	ret = skb_copy_datagram_iovec(skb, 0, m->msg_iov, total_len);
	if (ret == 0)
		ret = total_len;

#ifdef DEBUG
	printk(KERN_CRIT "%s:%d skb->len = %u, total_len = %zu\n", __FUNCTION__, __LINE__,
			skb->len, total_len);
#endif

	kfree_skb(skb);
	return ret;
}

/*
 * spc_sendmsg - transmit one SPC frame on the socket's bound device
 * @iocb: unused
 * @sock: sending socket; sk_bound_dev_if selects the egress device
 * @m: user message; only msg_iov is used, a destination address is refused
 * @total_len: number of user bytes to send
 *
 * Builds <link header to mcast_da, type ETH_P_SLOW><subtype><user data>,
 * zero-pads the frame up to ETH_ZLEN, and hands it to dev_queue_xmit().
 *
 * Return: @total_len on success, or a negative errno: -EISCONN (msg_name
 * given), -ENODEV (bound device gone), -EMSGSIZE (subtype + data exceed
 * the MTU), -ENOMEM (skb allocation failed), -EFAULT (bad user buffer),
 * -EINVAL (link header could not be built), or the transmit error.
 */
static int spc_sendmsg(struct kiocb *iocb, struct socket *sock,
		 struct msghdr *m, size_t total_len)
{
#ifdef DEBUG
	int k;
#endif
	int ret;
	struct net_device *dev;
	struct sk_buff *skb;
	struct sock *sk = sock->sk;
	unsigned int len;

	if (m->msg_name)
		return -EISCONN;

	dev = dev_get_by_index(&init_net, sk->sk_bound_dev_if);
	if (dev == NULL)
		return -ENODEV;

	/*
	 * The subtype octet plus user data must fit in the MTU.  Test
	 * total_len on its own (1 + total_len > mtu  <=>  total_len >= mtu)
	 * so a huge size_t cannot wrap the unsigned sum and slip past the
	 * check into an undersized skb.
	 */
	if (total_len >= dev->mtu) {
		ret = -EMSGSIZE;
		goto end;
	}

	/*
	 * Headroom for the device's hard header, plus a payload area big
	 * enough for the data or for padding the frame up to ETH_ZLEN,
	 * whichever is larger.  This keeps the padding below inside the
	 * tailroom even when hard_header_len exceeds the header actually
	 * pushed by dev_hard_header().
	 */
	len = dev->hard_header_len + max_t(unsigned int, 1 + total_len, ETH_ZLEN);

	skb = sock_wmalloc(sk, len, 0, GFP_KERNEL);
	if (!skb) {
		ret = -ENOMEM;
		goto end;
	}
#ifdef DEBUG
	printk(KERN_CRIT "%s:%d mtu = %u, hard_header_len = %hu, total_len = %zu, len = %u\n", __FUNCTION__, __LINE__,
			dev->mtu, dev->hard_header_len, total_len, len);
#endif

	/* Reserve space for headers. */
	skb_reserve(skb, dev->hard_header_len);
	skb_reset_network_header(skb);

	*(__u8 *)skb_put(skb, 1) = subtype;

	skb->dev = dev;
	skb->priority = sk->sk_priority;
	skb->protocol = cpu_to_be16(ETH_P_SLOW);

	ret = memcpy_fromiovec(skb_put(skb, total_len), m->msg_iov, total_len);
	if (ret < 0)
		goto out_free;

#ifdef DEBUG
	printk("\nO: SPC");
	for (k = 0; k < skb->len; k++) {
		if (k % 16 == 0)
			printk("\n  ");
		printk(" %2.2x", skb->data[k]);
	}
	printk("\n");
#endif

	if (dev_hard_header(skb, dev, ETH_P_SLOW, mcast_da, NULL, total_len) < 0) {
		ret = -EINVAL;
		goto out_free;
	}

	/* Zero-pad short frames up to the Ethernet minimum. */
	if (skb->len < ETH_ZLEN) {
		unsigned int pad = ETH_ZLEN - skb->len;

		memset(skb_put(skb, pad), 0, pad);
	}

	/*
	 * dev_queue_xmit() always consumes the skb.  Positive NET_XMIT_* codes
	 * are mapped to an errno (congestion counts as success).
	 */
	ret = dev_queue_xmit(skb);
	if (ret > 0)
		ret = net_xmit_errno(ret);
	if (ret == 0)
		ret = total_len;
	goto end;

out_free:
	kfree_skb(skb);
end:
	dev_put(dev);
	return ret;
}

static struct proto_ops SOCKOPS_WRAPPED(spc_proto_ops) = {
	family:		PF_SPC,
	release:	spc_release,
	bind:		0,
	connect:	0,
	socketpair:	0,
	accept:		0,
	getname:	0,
	poll:		spc_poll,
	ioctl:		0,
	listen:		0,
	shutdown:	0,
	setsockopt:	spc_setsockopt,
	getsockopt:	spc_getsockopt,
	sendmsg:	spc_sendmsg,
	recvmsg:	spc_recvmsg,
	mmap:		0,
	sendpage:	0,
};

#include <linux/smp_lock.h>
SOCKOPS_WRAP(spc_proto, PF_SPC);

/*
 * spc_create - net_proto_family create hook for PF_SPC
 * @net: namespace for the new socket
 * @sock: socket being created
 * @protocol: passed through to create(), which ignores it
 *
 * Installs spc_proto_ops on @sock, then builds the sock via create().
 */
static int spc_create(struct net *net, struct socket *sock, int protocol)
{
	const int family = PF_SPC;

	sock->ops = &spc_proto_ops;

	return create(net, sock, protocol, family);
}

/* Family registration: socket(PF_SPC, ...) is routed to spc_create(). */
static const struct net_proto_family spc_family_ops = {
	.owner	= THIS_MODULE,
	.family	= PF_SPC,
	.create	= spc_create,
};

/*
 * Receive hook for frames with EtherType ETH_P_SLOW.  No .dev is set,
 * which (per dev_add_pack() semantics) means frames from any device.
 */
static struct packet_type slow_packet_type __read_mostly = {
	.func	= spc_rcv,
	.type	= cpu_to_be16(ETH_P_SLOW),
};

/*
 * spc_init - pick the default device and register the SPC family
 *
 * The default device is chosen as follows:
 *   - Preferred: the first init_net device with irq == BSP_SW_IRQ,
 *     IFF_DOMAIN_ELAN set in priv_flags, and bit 0 of its private
 *     portmask set.
 *   - Fallback: the device named ALIASNAME_ELAN_PREFIX "2".
 * A reference is held on the chosen device until spc_exit().
 *
 * Return: 0 on success, -ENODEV if no device was found, or the error from
 * proto_register()/sock_register().  On failure everything acquired so
 * far is released.
 */
static int __init spc_init(void)
{
	int ret = -ENODEV;

	read_lock(&dev_base_lock);
	for_each_netdev(&init_net, default_dev) {
		/* select the switch device */
		if (BSP_SW_IRQ != default_dev->irq)
			continue;

		/* select the ELAN device */
		if ((default_dev->priv_flags & IFF_DOMAIN_ELAN) == 0)
			continue;

		/* select the first port */
		if ((((struct dev_priv *)default_dev->priv)->portmask & 0x1) == 0)
			continue;

		ret = 0;
		break;
	}

	/* No match: default_dev is not a valid device here, look one up by name. */
	if (ret) {
		default_dev = __dev_get_by_name(&init_net, ALIASNAME_ELAN_PREFIX "2");
		if (!default_dev) {
			read_unlock(&dev_base_lock);
			goto end;
		}
	}
	dev_hold(default_dev);
	read_unlock(&dev_base_lock);

	/* Receive-path state must be ready before any socket can be created. */
	init_waitqueue_head(&spc_sleep);
	skb_queue_head_init(&spc_receive_queue);

	ret = proto_register(&spc_proto, 1);
	if (ret) {
		printk(KERN_CRIT "%s: Cannot create SPC SLAB cache!\n",
		       __FUNCTION__);
		goto end_clean1;
	}

	ret = sock_register(&spc_family_ops);
	if (ret) {
		printk(KERN_ERR "SPC: can't register (%d)\n", ret);
		goto end_clean2;
	}

	dev_add_pack(&slow_packet_type);

	/* Success: keep the proto registration and the device reference. */
	return 0;

end_clean2:
	proto_unregister(&spc_proto);
end_clean1:
	dev_put(default_dev);
	default_dev = NULL;
end:
	return ret;
}

/*
 * spc_exit - undo spc_init()
 *
 * Unhooks the ETH_P_SLOW handler, drops any frames still waiting on the
 * receive queue, unregisters the socket family and proto, and releases
 * the reference on the default device.
 */
static void __exit spc_exit(void)
{
	dev_remove_pack(&slow_packet_type);
	sock_unregister(PF_SPC);
	/* Free frames queued by spc_rcv() that no reader consumed. */
	skb_queue_purge(&spc_receive_queue);
	proto_unregister(&spc_proto);
	dev_put(default_dev);
}

/*
 * spc_init() is registered as a late initcall, presumably so that the
 * switch/ELAN net devices it searches for are already registered by the
 * time it runs -- TODO confirm against the driver init order.
 */
late_initcall(spc_init);
module_exit(spc_exit);