/* SPDX-License-Identifier: GPL-2.0 */

#include "netlink.h"
#include <sys/types.h>
#include <sys/socket.h>
#include <linux/netlink.h>
#include <poll.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdint.h>
#include <stdbool.h>

/*
 * Set to 1 to make netlink_receive() print a one-line summary
 * (type, seq, pid, len) of every valid message it receives to stdout.
 */
#define DUMP_MSGS 0

/*
 * Open a NETLINK_AUDIT raw socket and bind it using this process's PID as
 * the netlink port id.
 *
 * Returns the socket file descriptor on success, or a negative errno value
 * on failure. Failures are also reported on stdout.
 */
int netlink_open(void)
{
	struct sockaddr_nl src_addr = { 0 };
	int fd, err;

	fd = socket(PF_NETLINK, SOCK_RAW, NETLINK_AUDIT);
	if (fd < 0) {
		/* capture errno before printf()/strerror() can overwrite it */
		err = errno;
		printf("Cannot open netlink socket: %s\n", strerror(err));
		return -err;
	}

	src_addr.nl_family = AF_NETLINK;
	src_addr.nl_pid = getpid();
	if (bind(fd, (struct sockaddr *)&src_addr, sizeof(src_addr))) {
		/* capture errno before printf() and close() can overwrite it */
		err = errno;
		printf("Cannot bind netlink socket to port %u (%s)\n",
			src_addr.nl_pid, strerror(err));
		close(fd);
		return -err;
	}

	return fd;
}

/*
 * Send one netlink message of the given type to the kernel (nl_pid 0),
 * with NLM_F_REQUEST | NLM_F_ACK set so the kernel acknowledges it.
 *
 * @fd:   socket returned by netlink_open()
 * @type: netlink message type
 * @data: payload to copy after the netlink header (may be NULL if size is 0)
 * @size: payload length in bytes
 *
 * Returns the sequence number assigned to the message (numbering starts at 1
 * and increments on every call that passes the size check), -EINVAL if the
 * payload does not fit in a msg_buffer or the message was only partially
 * sent, or a negative errno from sendto().
 */
int netlink_send(int fd, uint16_t type, const void *data, uint32_t size)
{
	static uint32_t sequence = 1;
	struct sockaddr_nl addr = { 0 };
	msg_buffer req = { 0 };
	struct nlmsghdr *nlh = (void *)req;
	ssize_t retval;

	/*
	 * Check size on its own first: NLMSG_SPACE() of a uint32_t wraps around
	 * for sizes close to UINT32_MAX, which would let a huge payload pass the
	 * second test and overflow req in the memcpy() below.
	 */
	if (size > sizeof(req) || NLMSG_SPACE(size) > sizeof(req))
		return -EINVAL;

	nlh->nlmsg_len = NLMSG_SPACE(size);
	nlh->nlmsg_type = type;
	nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
	nlh->nlmsg_seq = sequence++;
	/* memcpy() with a NULL source is undefined even for zero bytes */
	if (size)
		memcpy(NLMSG_DATA(nlh), data, size);

	/* nl_pid 0 addresses the kernel; no multicast groups */
	addr.nl_family = AF_NETLINK;
	addr.nl_pid = 0;
	addr.nl_groups = 0;

	do {
		retval = sendto(fd, &req, nlh->nlmsg_len, 0,
			(struct sockaddr *)&addr, sizeof(addr));
	} while (retval < 0 && errno == EINTR);

	if (retval < 0)
		return -errno;
	else if ((size_t)retval == nlh->nlmsg_len)
		return sequence - 1;
	else
		return -EINVAL;
}

/*
 * Receive one datagram from fd into *msg and check that it starts with a
 * valid netlink header.
 *
 * @fd:      socket returned by netlink_open()
 * @msg:     buffer to receive into; it is zeroed first
 * @noblock: if true, use MSG_DONTWAIT so the call does not block
 *
 * Returns true if a well-formed netlink message was received. Returns false
 * if recv() fails (including EAGAIN/EWOULDBLOCK when noblock is set and
 * nothing is queued) or if the received data fails NLMSG_OK(). Both failure
 * cases are reported on stdout.
 */
bool netlink_receive(int fd, msg_buffer *msg, bool noblock)
{
	int flags = noblock ? MSG_DONTWAIT : 0;
	struct nlmsghdr *nlh = (void *)*msg;
	ssize_t bytes;

	memset(*msg, 0, sizeof(*msg));
	/* read at most size - 1 bytes so the buffer is always zero terminated */
	bytes = recv(fd, *msg, sizeof(*msg) - 1, flags);

	if (bytes < 0) {
		printf("Error receiving msg: %s\n", strerror(errno));
		return false;
	}

	if (!NLMSG_OK(nlh, bytes)) {
		printf("Invalid netlink msg of size %zd received\n", bytes);
		return false;
	}

#if DUMP_MSGS
	printf("msg type=%u seq=%u pid=%u len=%u\n",
		nlh->nlmsg_type, nlh->nlmsg_seq,
		nlh->nlmsg_pid, nlh->nlmsg_len);
#endif

	return true;
}

/*
 * Send a netlink request and wait for the kernel's acknowledgement, and
 * optionally for a data reply.
 *
 * The message is sent with netlink_send(). The function then blocks,
 * discarding messages whose sequence number does not match, until it sees
 * an NLMSG_ERROR message (the ACK) for the request. If reply is non-NULL,
 * it also waits for a message of the same type and sequence number and
 * copies it into *reply. The ACK and the reply are accepted in either
 * order. Receive failures reported by netlink_receive() are retried, so
 * this can block indefinitely.
 *
 * Returns 0 on success, the non-zero error field of the ACK if the kernel
 * reported an error (presumably a negative errno per netlink(7)), or the
 * negative value returned by netlink_send(). If the reply arrived before an
 * ACK that carries an error, *reply has already been written.
 */
int netlink_request(int fd, uint16_t type, const void *data, uint32_t size, msg_buffer *reply)
{
	msg_buffer msg;
	struct nlmsghdr *nlh = (void *)msg;
	const struct nlmsgerr *ack = NLMSG_DATA(nlh);
	bool acked = false;
	bool replied = false;
	int seq;

	seq = netlink_send(fd, type, data, size);
	if (seq < 0)
		return seq;

	for (;;) {
		if (!netlink_receive(fd, &msg, false))
			continue;

		/* skip messages that belong to other requests */
		if ((int)nlh->nlmsg_seq != seq)
			continue;

		if (nlh->nlmsg_type == NLMSG_ERROR) {
			/* an error occurred, or the caller expects no reply */
			if (ack->error || !reply)
				return ack->error;
			acked = true;
		} else if (reply && nlh->nlmsg_type == type) {
			/*
			 * The kernel may send its reply before or after the ACK,
			 * so keep it whenever it shows up.
			 */
			memcpy(*reply, msg, sizeof(*reply));
			replied = true;
		}

		if (acked && replied)
			return 0;
	}
}