/* SPDX-License-Identifier: GPL-2.0 */

#include "netlink.h"
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/apparmor.h>
#include <linux/netlink.h>
#include <linux/audit.h>
#include <poll.h>
#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdint.h>
#include <stdbool.h>

/* Fixed path that the feature dump (-f option) is written to. */
#define AA_FEATUREOUT "/tmp/apparmor_features"

/* Timeout in milliseconds handed to poll() while dumping the audit backlog.
 * If no message arrives within this time, the dump stops.
 */
#define AUDIT_DUMP_TIMEOUT 200

/* Operating modes returned by parse_args().
 *
 * The order of the enumerators matters: main() opens a netlink socket only
 * for modes that lie strictly between _AUDIT_BEGIN and _AUDIT_END, so every
 * mode that needs the socket must be placed between these two sentinels.
 */
enum mode {
	MODE_HELP,             /* print usage (default when no known option is given) */
	_AUDIT_BEGIN,          /* sentinel: start of the modes that need the netlink socket */
	MODE_AUDIT_STATUS,     /* -s: print the audit status */
	MODE_AUDIT_LOGS,       /* -l: print the audit status, then dump the logs */
	MODE_AUDIT_WATCH,      /* -w: dump logs without a message limit */
	MODE_AUDIT_SET_BL,     /* -b <num>: set the audit backlog limit */
	_AUDIT_END,            /* sentinel: end of the modes that need the netlink socket */
	MODE_AA_FEATURE_DUMP,  /* -f: dump apparmor features to AA_FEATUREOUT */
};

/* Raw argument of the -b option (the requested backlog limit as text).
 * parse_args() stores optarg here; main() hands it to aa_set_bl().
 * It is only assigned when -b was parsed.
 */
static const char *bl_arg;

/* Query the current audit status via an AUDIT_GET request.
 *
 * fd:  netlink socket the request is sent on
 * out: filled with the payload of the reply on success
 *
 * Returns 0 on success. A non-zero result of netlink_request() is treated as
 * a failure: it is reported on stdout and returned unchanged, and *out is
 * left untouched.
 */
static int audit_status(int fd, struct audit_status *out)
{
	msg_buffer reply;
	const int err = netlink_request(fd, AUDIT_GET, NULL, 0, &reply);

	if (err) {
		printf("Failed to get audit status: %d\n", err);
		return err;
	}

	/* The status structure is the payload directly behind the netlink header. */
	memcpy(out, NLMSG_DATA((struct nlmsghdr *)(void *)reply), sizeof(*out));
	return err;
}

static int audit_set_backloglimit(int fd, uint32_t limit)
{
	struct audit_status s = { 0 };

	s.mask = AUDIT_STATUS_BACKLOG_LIMIT;
	s.backlog_limit = limit;

	return netlink_request(fd, AUDIT_SET, &s, sizeof(s), NULL);
}

/* Print selected fields of an audit status structure as a single line on
 * stdout.
 *
 * s: status to print
 */
static void audit_print_status(const struct audit_status *s)
{
	const unsigned int enabled = s->enabled;
	const unsigned int failure_action = s->failure;
	const unsigned int rate_limit = s->rate_limit;
	const unsigned int backlog_limit = s->backlog_limit;
	const unsigned int lost = s->lost;

	printf("status: [enabled=%u failure_action=%u rate_limit=%u backlog_limit=%u lost=%u]\n",
	       enabled, failure_action, rate_limit, backlog_limit, lost);
}

static int audit_dump_logs(int fd, const uint32_t backlog_limit)
{
	struct audit_status s = { 0 };
	struct pollfd watch = {
		.fd = fd,
		.events = POLLIN,
	};
	bool attached = false;
	uint32_t num = 0;
	int ret = 0, seq;
	int timeout = backlog_limit ? AUDIT_DUMP_TIMEOUT : -1;

	/* attach this process as auditd */
	s.mask  = AUDIT_STATUS_PID;
	s.pid = getpid();
	seq = netlink_send(fd, AUDIT_SET, &s, sizeof(s));
	if (seq < 0)
		return seq;

	/* The second we attach we get notifications by the kernel.
	 * Therefore we cannot wait for the ack as it can be interleaved with
	 * notifications.
	 */
	while ((!backlog_limit || num < backlog_limit) && (ret = poll(&watch, 1, timeout)) > 0) {
		msg_buffer msg;
		struct nlmsghdr *nlh = (void *)msg;
		void *payload = NLMSG_DATA(nlh);

		if (!(watch.revents & POLLIN)) {
			printf("revent error: 0x%x\n", watch.revents);
			break;
		}

		if (!netlink_receive(fd, &msg, true))
			break;

		switch (nlh->nlmsg_type) {
			case NLMSG_ERROR: {
				const struct nlmsgerr *err = payload;
				if ((int)nlh->nlmsg_seq != seq || err->error)
					break;
				attached = true;
				break;
			}
			case AUDIT_AVC: {
				num++;
				printf("%s\n", (const char *)payload);
				break;
			}
		}
	}

	if (ret < 0) {
		printf("Error during poll: %s\n", strerror(errno));
	}
	printf("log_summary: [attached_successful=%d num=%u]\n", attached, num);

	/* detach in order to stop notifications */
	if (attached) {
		s.pid = 0;
		ret = netlink_request(fd, AUDIT_SET, &s, sizeof(s), NULL);
		if (ret)
			printf("Failed to detach: %d\n", ret);
	}

	return ret;
}

/* Read the AppArmor features from the kernel and write them to
 * AA_FEATUREOUT.
 *
 * Errors are reported on stderr, success on stdout.
 *
 * Returns 0 on success and -1 if either reading or writing the features
 * failed.
 */
static int aa_dump_features(void)
{
	aa_features *kernel_features = NULL;
	int status = -1;

	if (aa_features_new_from_kernel(&kernel_features)) {
		fprintf(stderr, "Failed to get features from kernel.\n");
		return -1;
	}

	if (aa_features_write_to_file(kernel_features, -1, AA_FEATUREOUT)) {
		fprintf(stderr, "Failed to write features to: %s.\n", AA_FEATUREOUT);
	} else {
		printf("Features dumped to: %s\n", AA_FEATUREOUT);
		status = 0;
	}

	/* The reference is dropped no matter whether writing succeeded. */
	aa_features_unref(kernel_features);
	return status;
}

/* Parse a backlog limit from text, apply it and print the old and new value.
 *
 * fd:  netlink socket to use
 * arg: textual limit; parsed by strtol() with base 0, so decimal, octal
 *      ("0...") and hexadecimal ("0x...") notation are accepted
 *
 * The whole argument must be a number in the range 1..UINT32_MAX. Trailing
 * characters and values that do not fit into the 32 bit backlog_limit field
 * are rejected instead of being silently truncated.
 *
 * Returns 0 on success, EINVAL for an invalid argument, or the non-zero
 * result of the failing audit request.
 */
static int aa_set_bl(int fd, const char *arg)
{
	struct audit_status before, after;
	char *endptr;
	long limit;
	int ret;

	/* strtol() reports overflow only through errno, so reset it first. */
	errno = 0;
	limit = strtol(arg, &endptr, 0);
	if (arg == endptr || *endptr != '\0') {
		printf("Invalid number: %s\n", arg);
		return EINVAL;
	}

	if (limit <= 0) {
		printf("Limit must be > 0\n");
		return EINVAL;
	}

	/* limit is positive here, so widening to unsigned long long is safe. */
	if (errno == ERANGE || (unsigned long long)limit > UINT32_MAX) {
		printf("Limit must be <= %lu\n", (unsigned long)UINT32_MAX);
		return EINVAL;
	}

	ret = audit_status(fd, &before);
	if (ret)
		return ret;

	ret = audit_set_backloglimit(fd, (uint32_t)limit);
	if (ret)
		return ret;

	/* Read the status back to show the value that is in effect afterwards. */
	ret = audit_status(fd, &after);
	if (ret)
		return ret;

	printf("backlog_limit %u -> %u\n", before.backlog_limit, after.backlog_limit);
	return ret;
}

/* Print the command line help to stdout. */
static void print_usage(void)
{
	fputs("usage: apparmor [option]\n", stdout);
	fputs("    -s        Show audit status\n", stdout);
	fputs("    -l        Dump audit status + backlog + summary to stdout\n", stdout);
	fputs("    -w        Watch audit logs (they are not printed to kmesg while this is running)\n", stdout);
	/* only this line needs formatting: it names the dump target */
	printf("    -f        Dump apparmor features to %s\n", AA_FEATUREOUT);
	fputs("    -b <num>  Sets audit backlog_limit to <num>\n", stdout);
}

/* Translate the first command line option into an operating mode.
 *
 * getopt() is called exactly once, so only the first option is evaluated.
 * For -b the option argument is stored in bl_arg as a side effect.
 *
 * Returns the selected mode, or MODE_HELP if getopt() returns anything other
 * than one of the known option characters.
 */
static enum mode parse_args(int argc, char **argv)
{
	switch (getopt(argc, argv, "slwfb:")) {
	case 's':
		return MODE_AUDIT_STATUS;
	case 'l':
		return MODE_AUDIT_LOGS;
	case 'w':
		return MODE_AUDIT_WATCH;
	case 'f':
		return MODE_AA_FEATURE_DUMP;
	case 'b':
		bl_arg = optarg;
		return MODE_AUDIT_SET_BL;
	default:
		return MODE_HELP;
	}
}

/* Entry point: select a mode from the command line and run it.
 *
 * A netlink socket is opened only for the audit modes (those between the
 * _AUDIT_BEGIN and _AUDIT_END sentinels) and closed again before returning.
 * For all other modes no descriptor exists, so nothing is closed.
 *
 * Returns 0 on success, otherwise the error value of the failing step.
 */
int main(int argc, char** argv)
{
	enum mode mode = parse_args(argc, argv);
	int fd = -1, ret = 0;

	if (mode > _AUDIT_BEGIN && mode < _AUDIT_END) {
		fd = netlink_open();
		if (fd < 0)
			return fd;
	}

	if (mode == MODE_AUDIT_STATUS || mode == MODE_AUDIT_LOGS) {
		struct audit_status status;

		ret = audit_status(fd, &status);
		if (ret)
			goto close;
		audit_print_status(&status);

		/* The current backlog limit bounds the number of dumped messages. */
		if (mode == MODE_AUDIT_LOGS)
			ret = audit_dump_logs(fd, status.backlog_limit);
	} else if (mode == MODE_AUDIT_WATCH) {
		/* 0 = no message limit */
		ret = audit_dump_logs(fd, 0);
	} else if (mode == MODE_AA_FEATURE_DUMP) {
		ret = aa_dump_features();
	} else if (mode == MODE_AUDIT_SET_BL) {
		ret = aa_set_bl(fd, bl_arg);
	} else {
		print_usage();
	}

close:
	/* fd is still -1 for the modes that never opened a socket */
	if (fd >= 0)
		close(fd);
	return ret;
}