// SPDX-License-Identifier: GPL-2.0+

#include <avm/pa/avm_pa.h>
#include <linux/if_ether.h>
#include <linux/if_vlan.h>
#include <linux/if_pppox.h>
#include <linux/ppp_defs.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <linux/tcp.h>
#include <linux/random.h>
#include <linux/etherdevice.h>
#include <net/ip6_checksum.h>

/* Fold the four 32-bit words of an IPv6 address into one word via XOR.
 * The argument is parenthesized so that pointer expressions such as
 * "addr + 4" index correctly.
 */
#define IPV6_ADDR_XOR(a) ((a)[0] ^ (a)[1] ^ (a)[2] ^ (a)[3])
/* Private state kept in skb->cb while a test packet is assembled from the
 * innermost header outwards by the hwpa_pkt_push_*() helpers.
 */
struct hwpa_pkt_cb {
	/* IP protocol number of the header pushed last; the IPv4/IPv6 push
	 * helpers copy it into their protocol/nexthdr field and then replace
	 * it with IPPROTO_IPIP / IPPROTO_IPV6 respectively.
	 */
	u16 ipproto;
};

/**
 * @brief compare the data of two skbs byte by byte and log any difference.
 *
 * Only the bytes reachable through skb->data are inspected.
 * NOTE(review): this presumably expects both skbs to be linear - confirm
 * against the callers.
 *
 * @param skb1 [in] first packet
 * @param skb2 [in] second packet
 * @return 0 if both packets are identical, 1 if bytes differ within the
 *         common length (the differing range is hex-dumped), 2 if the
 *         common bytes match but the lengths differ
 */
int hwpa_pktcmp(const struct sk_buff *skb1, const struct sk_buff *skb2)
{
	unsigned int i, common_len;
	int first_d, last_d;

	/* skb->len is unsigned; compute the bound once and index unsigned */
	common_len = min(skb1->len, skb2->len);

	first_d = last_d = -1;
	for (i = 0; i < common_len; i++) {
		if (skb1->data[i] != skb2->data[i]) {
			if (first_d == -1)
				first_d = i;
			last_d = i;
		}
	}

	if (first_d != -1) {
		printk(KERN_DEBUG
		       "skb1 (%pK) and skb2 (%pK) differ at offset %d:\n",
		       skb1, skb2, first_d);
		/* dump the span from the first to the last differing byte */
		print_hex_dump_bytes("skb1: ", DUMP_PREFIX_ADDRESS,
				     &skb1->data[first_d],
				     (last_d + 1) - first_d);
		print_hex_dump_bytes("skb2: ", DUMP_PREFIX_ADDRESS,
				     &skb2->data[first_d],
				     (last_d + 1) - first_d);

		return 1;
	} else if (skb1->len != skb2->len) {
		/* len is unsigned int, so print it with %u */
		printk(KERN_DEBUG
		       "skb1 (%pK) and skb2 (%pK) differ in len: %u vs. %u\n",
		       skb1, skb2, skb1->len, skb2->len);
		return 2;
	}

	return 0;
}

/**
 * @brief look up the first header of a given type in a packet match.
 *
 * @param match [in] packet match whose match[] entries are searched
 * @param type [in] match type to look for
 * @return pointer into the match's header copy at the entry's offset, or
 *         NULL if no entry of that type exists
 */
const void *hwpa_get_hdr(const struct avm_pa_pkt_match *match,
			 unsigned char type)
{
	unsigned int idx = 0;

	while (idx < match->nmatch) {
		const struct avm_pa_match_info *entry = &match->match[idx++];

		if (entry->type != type)
			continue;

		return (void *)(HDRCOPY(match) + entry->offset);
	}

	return NULL;
}

/**
 * @brief prepend a TCP header with random ports and sequence numbers.
 *
 * Records IPPROTO_TCP in the skb control block for the IP header that is
 * pushed next and registers the checksum field position on the skb.
 * NOTE(review): doff is left at 0 by the memset - confirm that consumers
 * do not need a valid data offset.
 *
 * @param skb [in] packet to extend
 * @return the same skb
 */
struct sk_buff *hwpa_pkt_push_tcp(struct sk_buff *skb)
{
	struct hwpa_pkt_cb *pkt_cb = (void *)&skb->cb[0];
	struct tcphdr *th;

	pkt_cb->ipproto = IPPROTO_TCP;

	th = (void *)skb_push(skb, sizeof(*th));
	memset(th, 0, sizeof(*th));

	get_random_bytes(&th->source, sizeof(th->source));
	get_random_bytes(&th->dest, sizeof(th->dest));
	th->ack = 1;
	get_random_bytes(&th->seq, sizeof(th->seq));
	get_random_bytes(&th->ack_seq, sizeof(th->ack_seq));

	/* checksum covers the packet from this header (offset 0) onwards */
	skb_partial_csum_set(skb, 0, offsetof(struct tcphdr, check));

	return skb;
}

/**
 * @brief prepend a UDP header with random ports.
 *
 * Records IPPROTO_UDP in the skb control block for the IP header that is
 * pushed next and registers the checksum field position on the skb.
 *
 * @param skb [in] packet to extend
 * @return the same skb
 */
struct sk_buff *hwpa_pkt_push_udp(struct sk_buff *skb)
{
	struct hwpa_pkt_cb *pkt_cb = (void *)&skb->cb[0];
	struct udphdr *uh;

	pkt_cb->ipproto = IPPROTO_UDP;

	uh = (void *)skb_push(skb, sizeof(*uh));
	memset(uh, 0, sizeof(*uh));

	get_random_bytes(&uh->source, sizeof(uh->source));
	get_random_bytes(&uh->dest, sizeof(uh->dest));
	/* skb->len already includes the header that was just pushed */
	uh->len = htons(skb->len);

	/* checksum covers the packet from this header (offset 0) onwards */
	skb_partial_csum_set(skb, 0, offsetof(struct udphdr, check));

	return skb;
}

struct sk_buff *hwpa_pkt_push_ipv4(struct sk_buff *skb)
{
	struct iphdr *ip;
	struct hwpa_pkt_cb *cb = (void *)&skb->cb[0];
	unsigned char *csum_start;
	unsigned int csum_len;
	__sum16 *csum_storage;
	__wsum sum;

	skb->protocol = ETH_P_IP;
	ip = (void *)skb_push(skb, sizeof(*ip));
	if (skb_inner_network_offset(skb) < 0)
		skb_reset_inner_network_header(skb);
	skb_reset_network_header(skb);
	memset(ip, 0, sizeof(*ip));

	ip->ihl = 5;
	ip->version = 4;
	ip->ttl = 0x3f; /* PAE default for encapsulation */
	ip->tos = 0;
	ip->protocol = cb->ipproto;
	get_random_bytes(&ip->saddr, sizeof(ip->saddr));
	get_random_bytes(&ip->daddr, sizeof(ip->daddr));
	ip->id = htons(0xab90); /* PAE default for encapsulation */
	ip->frag_off = 0;
	ip->tot_len = htons(skb->len);
	ip->check = 0;
	ip->check = ip_fast_csum((void *)ip, ip->ihl);

	csum_start = skb->head + skb->csum_start;
	csum_len = skb_tail_pointer(skb) - csum_start;
	csum_storage = (__sum16 *)(csum_start + skb->csum_offset);

	sum = csum_partial(csum_start, csum_len, 0);
	*csum_storage = csum_tcpudp_magic(ip->saddr, ip->daddr, csum_len,
					  cb->ipproto, sum);

	cb->ipproto = IPPROTO_IPIP;

	return skb;
}

struct sk_buff *hwpa_pkt_push_ipv6(struct sk_buff *skb)
{
	struct ipv6hdr *ip;
	struct hwpa_pkt_cb *cb = (void *)&skb->cb[0];
	unsigned char *csum_start;
	unsigned int csum_len;
	__sum16 *csum_storage;
	__wsum sum;

	skb->protocol = ETH_P_IPV6;
	ip = (void *)skb_push(skb, sizeof(*ip));
	if (skb_inner_network_offset(skb) < 0)
		skb_reset_inner_network_header(skb);
	skb_reset_network_header(skb);
	memset(ip, 0, sizeof(*ip));

	ip->version = 6;
	ip->payload_len = htons(skb->len - sizeof(*ip));
	ip->nexthdr = cb->ipproto;
	ip->hop_limit = 0xff; /* PAE default for encapsulation */
	get_random_bytes(&ip->saddr, sizeof(ip->saddr));
	get_random_bytes(&ip->daddr, sizeof(ip->daddr));

	csum_start = skb->head + skb->csum_start;
	csum_len = skb_tail_pointer(skb) - csum_start;
	csum_storage = (__sum16 *)(csum_start + skb->csum_offset);

	sum = csum_partial(csum_start, csum_len, 0);
	*csum_storage = csum_ipv6_magic(&ip->saddr, &ip->daddr, csum_len,
					cb->ipproto, sum);

	cb->ipproto = IPPROTO_IPV6;

	return skb;
}

struct sk_buff *hwpa_pkt_push_pppoe(struct sk_buff *skb)
{
	struct pppoe_hdr *pppoeh;

	pppoeh = (void *)skb_push(skb, PPPOE_SES_HLEN);

	if (skb->protocol == ETH_P_IP) {
		pppoeh->tag[0].tag_type = htons(PPP_IP);
	} else {
		pppoeh->tag[0].tag_type = htons(PPP_IPV6);
	}
	skb->protocol = ETH_P_PPP_SES;

	/* "The VER field is four bits and MUST be set to 0x1 for this version
	 * of the PPPoE specification."
	 */
	pppoeh->ver = 1;

	/* "The TYPE field is four bits and MUST be set to 0x1 for this
	 * version of the PPPoE specification."
	 */
	pppoeh->type = 1;

	/* Session stage -> code = 0 */
	pppoeh->code = 0;

	get_random_bytes(&pppoeh->sid, sizeof(pppoeh->sid));
	pppoeh->length = htons(skb->len - 6);

	return skb;
}

struct sk_buff *hwpa_pkt_push_vlan(struct sk_buff *skb)
{
	struct vlan_hdr *vlanh;

	vlanh = (void *)skb_push(skb, sizeof(*vlanh));

	get_random_bytes(&vlanh->h_vlan_TCI, sizeof(vlanh->h_vlan_TCI));
	vlanh->h_vlan_TCI %= VLAN_N_VID;
	vlanh->h_vlan_encapsulated_proto = htons(skb->protocol);
	skb->protocol = ETH_P_8021Q;

	return skb;
}

/**
 * @brief prepend an Ethernet header with random MAC addresses.
 *
 * The ethertype is taken from skb->protocol, which the helpers in this
 * file keep in host byte order.
 *
 * @param skb [in] packet to extend
 * @return the same skb
 */
struct sk_buff *hwpa_pkt_push_eth(struct sk_buff *skb)
{
	struct ethhdr *hdr = (void *)skb_push(skb, sizeof(struct ethhdr));

	eth_random_addr(hdr->h_dest);
	eth_random_addr(hdr->h_source);
	hdr->h_proto = htons(skb->protocol);

	return skb;
}

struct sk_buff *hwpa_pkt_alloc(unsigned int payload)
{
	struct sk_buff *skb;

	skb = alloc_skb(AVM_PA_MAX_HEADER + payload, GFP_ATOMIC);
	skb_reserve(skb, AVM_PA_MAX_HEADER);
	memset(skb_put(skb, payload), 0, payload);

	return skb;
}

/**
 * @brief find the net_device registered for a pid handle and hold a
 *        reference. Use dev_put if done.
 *
 * All network namespaces are searched under the RCU read lock.
 *
 * @param pid [in] pid handle to look for
 * @return matching net_device with a held reference, or NULL if none
 */
struct net_device *hwpa_get_netdev(avm_pid_handle pid)
{
	struct net_device *found = NULL;
	struct net_device *dev;
	struct net *net;

	rcu_read_lock();
	for_each_net_rcu (net) {
		for_each_netdev_rcu (net, dev) {
			if (AVM_PA_DEVINFO(dev)->pid_handle != pid)
				continue;

			/* take the reference before leaving the RCU section */
			dev_hold(dev);
			found = dev;
			goto unlock;
		}
	}
unlock:
	rcu_read_unlock();

	return found;
}

/**
 * @fn struct net_device *hwpa_get_and_hold_dev_master(struct net_device*)
 * @brief look up the master upper device of a net_device and hold a
 *        reference to it. Use dev_put if done.
 *
 * @param dev [in] net_device whose master is requested
 * @return master with a held reference, or NULL if there is none
 */
struct net_device *hwpa_get_and_hold_dev_master(struct net_device *dev)
{
	struct net_device *upper;

	rcu_read_lock();
	upper = netdev_master_upper_dev_get_rcu(dev);
	/* take the reference while still inside the RCU section */
	if (upper)
		dev_hold(upper);
	rcu_read_unlock();

	return upper;
}

/**
 * @fn uint32_t hwpa_ipv4_gen_session_hash_raw(uint32_t, uint32_t, uint32_t, uint32_t, uint8_t)
 * @brief XOR all ipv4 session properties into one hash value. Because XOR
 *        is commutative, swapping the flow and return tuples yields the
 *        same hash (the hash is symmetric).
 *
 * @param flow_ip [in] flow ip folded into the hash
 * @param flow_ident [in] flow ident folded into the hash
 * @param return_ip_xlate [in] translated return ip folded into the hash
 * @param return_ident_xlate [in] translated return ident folded into the hash
 * @param protocol [in] protocol folded into the hash
 *
 * @return the generated hash value
 */
uint32_t hwpa_ipv4_gen_session_hash_raw(uint32_t flow_ip, uint32_t flow_ident,
				       uint32_t return_ip_xlate, uint32_t return_ident_xlate,
				       uint8_t protocol)
{
	const uint32_t flow_part = flow_ip ^ flow_ident;
	const uint32_t return_part = return_ip_xlate ^ return_ident_xlate;

	return flow_part ^ return_part ^ (uint32_t)protocol;
}

/**
 * @fn uint32_t hwpa_ipv6_gen_session_hash_raw(uint32_t*, uint32_t, uint32_t*, uint32_t, uint8_t)
 * @brief XOR all ipv6 session properties into one hash value. Because XOR
 *        is commutative, swapping the flow and return tuples yields the
 *        same hash (the hash is symmetric).
 *
 * @param flow_ip [in] flow ip; four 32-bit words are read and folded in
 * @param flow_ident [in] flow ident folded into the hash
 * @param return_ip [in] return ip; four 32-bit words are read and folded in
 * @param return_ident [in] return ident folded into the hash
 * @param protocol [in] protocol folded into the hash
 *
 * @return the generated hash value
 */
uint32_t hwpa_ipv6_gen_session_hash_raw(uint32_t *flow_ip, uint32_t flow_ident,
				       uint32_t *return_ip, uint32_t return_ident,
				       uint8_t protocol)
{
	const uint32_t flow_part = IPV6_ADDR_XOR(flow_ip) ^ flow_ident;
	const uint32_t return_part = IPV6_ADDR_XOR(return_ip) ^ return_ident;

	return flow_part ^ return_part ^ (uint32_t)protocol;
}