/*
 * Userland tool for the unicache.
 *
 * Copyright (C) 2006 Robert Olsson <Robert.Olsson@its.uu.se>
 *                    Uppsala, Sweden
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <getopt.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <linux/netlink.h>
#include "unicache.h"
#include "trie_core.h"

/* Sync with include/linux/netlink.h */
#define NETLINK_UNICACHE        20

#define NL_MAX_PAYLOAD 1024  

/*
 * Dump @ne 32-bit words from @data to stdout as zero-padded hex, each
 * followed by a space, and finish the line with a newline.  Nothing but
 * the newline is printed when @ne <= 0.
 */
static void print_msg(__u32 *data, int ne)
{
	int idx = 0;

	while (idx < ne) {
		printf("%08x ", data[idx]);
		idx++;
	}
	putchar('\n');
}

/*
 * Long options for getopt_long().  main() passes an empty optstring, so
 * these are the only options.  Each entry's .val is the code that
 * getopt_long() returns for it, and main() switches on that code.  "help"
 * uses UNICACHE_MAX+1, presumably chosen to lie past every UNICACHE_*
 * command code (confirm against unicache.h).
 *
 * NOTE(review): "get_log_mask" yields UNICACHE_SET_LOG_MASK, unlike every
 * other get_* entry.  It was probably meant to be a GET_* code; verify in
 * unicache.h.  main() has no case for set_log_mask, get_log_mask,
 * log_flow_start or remove, so those end up in its default branch.
 */
static struct option long_options[] = {
	{ .name = "set_gc_thresh",           .has_arg = required_argument, .flag = NULL, .val = UNICACHE_SET_GC_THRESH },
	{ .name = "get_gc_thresh",           .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_GET_GC_THRESH },
	{ .name = "set_gc_goal",             .has_arg = required_argument, .flag = NULL, .val = UNICACHE_SET_GC_GOAL },
	{ .name = "get_gc_goal",             .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_GET_GC_GOAL },
	{ .name = "set_gc_level",            .has_arg = required_argument, .flag = NULL, .val = UNICACHE_SET_GC_LEVEL },
	{ .name = "get_gc_level",            .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_GET_GC_LEVEL },
	{ .name = "set_log_mask",            .has_arg = required_argument, .flag = NULL, .val = UNICACHE_SET_LOG_MASK },
	{ .name = "set_log_sample_rate",     .has_arg = required_argument, .flag = NULL, .val = UNICACHE_SET_LOG_SAMPLE },
	{ .name = "set_timestamp_flow_rate", .has_arg = required_argument, .flag = NULL, .val = UNICACHE_SET_TIMESTAMP_FLOW },
	{ .name = "get_timestamp_flow_rate", .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_GET_TIMESTAMP_FLOW },
	{ .name = "get_log_mask",            .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_SET_LOG_MASK },
	{ .name = "log_flow_start",          .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_LOG_FLOW_START },
	{ .name = "log_flow_end",            .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_LOG_FLOW_END },
	{ .name = "insert",                  .has_arg = required_argument, .flag = NULL, .val = UNICACHE_INSERT_FLOW },
	{ .name = "remove",                  .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_REMOVE_FLOW },
	{ .name = "flush",                   .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_FLUSH },
	{ .name = "help",                    .has_arg = no_argument,       .flag = NULL, .val = UNICACHE_MAX + 1 },
	/* all-zero terminator required by getopt_long() */
	{ NULL, 0, NULL, 0 }
};

/*
 * Send one request to the unicache kernel module without waiting for a
 * reply.
 *
 * @s:        netlink socket
 * @groups:   multicast group mask placed in the destination address
 * @type:     UNICACHE_* command, sent as nlmsg_type
 * @obj:      request payload, copied right after the netlink header
 * @obj_size: payload length in bytes; must not exceed NL_MAX_PAYLOAD
 *
 * The message is always NLMSG_SPACE(NL_MAX_PAYLOAD) bytes long, and the
 * part of the payload beyond @obj_size is zero-filled.
 *
 * Returns 0 on success, or -1 on failure after printing a diagnostic.
 */
int msend_generic(int s, __u32 groups, __u16 type, void *obj, size_t obj_size)
{
	const size_t bufsize = NLMSG_SPACE(NL_MAX_PAYLOAD);
	struct sockaddr_nl dst_addr;
	struct nlmsghdr *nlh;
	struct iovec iov;
	struct msghdr msg;
	int rc = 0;

	if (obj_size > NL_MAX_PAYLOAD) {
		/* errno is not set on this path, so perror() would print a stale error. */
		fprintf(stderr, "msend_generic: payload of %zu bytes exceeds %d\n",
			obj_size, NL_MAX_PAYLOAD);
		return -1;
	}

	/* calloc() zero-fills both the header and the unused payload tail. */
	nlh = calloc(1, bufsize);
	if (!nlh) {
		perror("calloc");
		return -1;
	}

	memset(&dst_addr, 0, sizeof(dst_addr));
	dst_addr.nl_family = AF_NETLINK;
	dst_addr.nl_pid = 0;		/* nl_pid 0: destination is the kernel */
	dst_addr.nl_groups = groups;

	nlh->nlmsg_len = (__u32)bufsize;
	nlh->nlmsg_type = type;
	memcpy(NLMSG_DATA(nlh), obj, obj_size);

	iov.iov_base = nlh;
	iov.iov_len = bufsize;

	memset(&msg, 0, sizeof(msg));
	msg.msg_name = &dst_addr;
	msg.msg_namelen = sizeof(dst_addr);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	if (sendmsg(s, &msg, 0) < 0) {
		perror("sendmsg");
		rc = -1;
	}

	free(nlh);
	return rc;
}

/*
 * Send a request to the unicache kernel module and read back its reply.
 *
 * @s:        netlink socket
 * @groups:   multicast group mask placed in the destination address
 * @type:     UNICACHE_* command, sent as nlmsg_type
 * @obj:      in: request payload; out: the first @obj_size bytes of the
 *            reply payload.  Left untouched if no usable reply arrives.
 * @obj_size: payload length in bytes; must not exceed NL_MAX_PAYLOAD
 *
 * The request is always NLMSG_SPACE(NL_MAX_PAYLOAD) bytes long, with the
 * unused payload tail zero-filled.  The next message received on @s is
 * taken as the reply.
 * NOTE(review): if @s has joined a multicast group, an unrelated group
 * message could be taken for the reply.
 *
 * Returns 0 on success, or -1 on failure after printing a diagnostic.
 */
int generic(int s, __u32 groups, __u16 type, int *obj, size_t obj_size)
{
	const size_t bufsize = NLMSG_SPACE(NL_MAX_PAYLOAD);
	struct sockaddr_nl dst_addr;
	struct nlmsghdr *nlh;
	struct iovec iov;
	struct msghdr msg;
	ssize_t len;
	int rc = -1;

	if (obj_size > NL_MAX_PAYLOAD) {
		/* errno is not set on this path, so perror() would print a stale error. */
		fprintf(stderr, "generic: payload of %zu bytes exceeds %d\n",
			obj_size, NL_MAX_PAYLOAD);
		return -1;
	}

	/* calloc() zero-fills both the header and the unused payload tail. */
	nlh = calloc(1, bufsize);
	if (!nlh) {
		perror("calloc");
		return -1;
	}

	memset(&dst_addr, 0, sizeof(dst_addr));
	dst_addr.nl_family = AF_NETLINK;
	dst_addr.nl_pid = 0;		/* nl_pid 0: destination is the kernel */
	dst_addr.nl_groups = groups;

	nlh->nlmsg_len = (__u32)bufsize;
	nlh->nlmsg_type = type;
	memcpy(NLMSG_DATA(nlh), obj, obj_size);

	iov.iov_base = nlh;
	iov.iov_len = bufsize;

	memset(&msg, 0, sizeof(msg));
	msg.msg_name = &dst_addr;
	msg.msg_namelen = sizeof(dst_addr);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	if (sendmsg(s, &msg, 0) < 0) {
		/* Nothing was sent, so do not block waiting for a reply. */
		perror("sendmsg");
		goto out;
	}

	/* Reuse the same buffer and msghdr to receive the reply. */
	memset(nlh, 0, bufsize);
	msg.msg_namelen = sizeof(dst_addr);
	len = recvmsg(s, &msg, 0);
	if (len < 0) {
		perror("recvmsg");
		goto out;
	}
	/* The reply must cover the header plus the bytes we copy out. */
	if ((size_t)len < NLMSG_LENGTH(obj_size)) {
		fprintf(stderr, "generic: short reply (%zd bytes)\n", len);
		goto out;
	}

	memcpy(obj, NLMSG_DATA(nlh), obj_size);
	rc = 0;
out:
	free(nlh);
	return rc;
}
	
/*
 * Print netlink messages received on @s until a fatal receive error.
 *
 * Each message is printed on one line: its nlmsg_seq, then its payload
 * as 32-bit hex words (see print_msg()).  A datagram may hold several
 * messages, and each is validated with NLMSG_OK() before it is read.
 *
 * @s:        netlink socket to read from
 * @src_sddr: unused; kept for interface compatibility
 */
void dump_flow(int s, struct sockaddr_nl *src_sddr)
{
	const size_t bufsize = NLMSG_SPACE(NL_MAX_PAYLOAD);
	struct nlmsghdr *buf;
	struct iovec iov;
	struct msghdr msg;

	(void)src_sddr;

	buf = calloc(1, bufsize);
	if (!buf) {
		perror("calloc");
		return;
	}

	iov.iov_base = buf;
	iov.iov_len = bufsize;

	/*
	 * Zero every msghdr field.  recvmsg() writes through msg_name and
	 * msg_control when they are non-NULL, so garbage there is unsafe.
	 */
	memset(&msg, 0, sizeof(msg));
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	for (;;) {
		struct nlmsghdr *nlh;
		int len = (int)recvmsg(s, &msg, 0);

		if (len < 0) {
			int err = errno;

			if (err == EINTR)
				continue;
			perror("recvmsg");
			/*
			 * ENOBUFS: the receive queue overran and messages were
			 * dropped, but the socket is still usable.
			 */
			if (err == ENOBUFS)
				continue;
			break;
		}

		/* NLMSG_OK() guarantees header <= nlmsg_len <= bytes received. */
		for (nlh = buf; NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
			int nwords = (int)(NLMSG_PAYLOAD(nlh, 0) / sizeof(__u32));

			printf("%08x ", nlh->nlmsg_seq);
			print_msg(NLMSG_DATA(nlh), nwords);
		}
		/* Make records visible promptly when stdout is a pipe. */
		fflush(stdout);
	}

	free(buf);
}

/*
 * Send a UNICACHE_INSERT_FLOW request whose payload is the LPK 32-bit
 * words at @key.  Any error reported by msend_generic() is ignored.
 */
void insert_flow(int s, __u32 *key)
{
	const size_t key_bytes = LPK * sizeof(*key);

	(void)msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_INSERT_FLOW,
			    key, key_bytes);
}

/*
 * Send a UNICACHE_FLUSH request.
 *
 * The request is sent with a single int payload, which the kernel side
 * presumably ignores (confirm in the module).  The payload is zeroed so
 * that uninitialised stack bytes are not copied into the message.
 */
void do_flush(int s)
{
	int key = 0;

	msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_FLUSH, &key, sizeof(key));
}

/*
 * Print the program version and an option summary to stdout, then
 * terminate the process with exit(-1).
 *
 * @av: the program's argv; av[0] is printed as the program name.
 */
void usage(char **av)
{
	printf("\n%s version=%s LPK=%d\n", av[0], VERSION, LPK);

	printf("Implemented functions:\n\n");

	printf("  --set_gc_thresh NUMBER\n");
	printf("  --set_gc_goal NUMBER\n");
	printf("  --set_gc_level MASK\n");
	printf("  --set_log_sample_rate NUMBER\n");
	printf("  --set_timestamp_flow_rate NUMBER\n");
	printf("\n");
	printf("  --get_gc_thresh\n");
	printf("  --get_gc_goal\n");
	printf("  --get_gc_level\n");
	/* Must match the long option name exactly, or getopt_long() rejects it. */
	printf("  --get_timestamp_flow_rate\n");
	printf("\n");
	printf("  --log_flow_end\n");
	printf("  --flush\n");
	printf("  --insert ...FLOW\n");

	exit(-1);
}

/* Send UNICACHE_SET_GC_THRESH with @thresh as a single-int payload. */
void set_gc_thresh(int s, int thresh)
{
	(void)msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_SET_GC_THRESH,
			    &thresh, sizeof thresh);
}

/* Send UNICACHE_SET_GC_GOAL with @goal as a single-int payload. */
void set_gc_goal(int s, int goal)
{
	(void)msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_SET_GC_GOAL,
			    &goal, sizeof goal);
}

/* Send UNICACHE_SET_GC_LEVEL with @level as a single-int payload. */
void set_gc_level(int s, int level)
{
	(void)msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_SET_GC_LEVEL,
			    &level, sizeof level);
}

/* Send UNICACHE_SET_LOG_SAMPLE with @val as a single-int payload. */
void set_log_sample(int s, int val)
{
	(void)msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_SET_LOG_SAMPLE,
			    &val, sizeof val);
}

/* Send UNICACHE_SET_TIMESTAMP_FLOW with @val as a single-int payload. */
void set_timestamp_flow(int s, int val)
{
	(void)msend_generic(s, UNICACHE_GRP_IPV4, UNICACHE_SET_TIMESTAMP_FLOW,
			    &val, sizeof val);
}

/*
 * Request UNICACHE_GET_GC_THRESH through generic(), which writes the
 * int-sized reply payload into *val.  generic()'s status is discarded.
 */
void get_gc_thresh(int s, int *val)
{
	(void)generic(s, UNICACHE_GRP_IPV4, UNICACHE_GET_GC_THRESH,
		      val, sizeof *val);
}

/*
 * Request UNICACHE_GET_GC_GOAL through generic(), which writes the
 * int-sized reply payload into *val.  generic()'s status is discarded.
 */
void get_gc_goal(int s, int *val)
{
	(void)generic(s, UNICACHE_GRP_IPV4, UNICACHE_GET_GC_GOAL,
		      val, sizeof *val);
}

/*
 * Request UNICACHE_GET_GC_LEVEL through generic(), which writes the
 * int-sized reply payload into *val.  generic()'s status is discarded.
 */
void get_gc_level(int s, int *val)
{
	(void)generic(s, UNICACHE_GRP_IPV4, UNICACHE_GET_GC_LEVEL,
		      val, sizeof *val);
}

/*
 * Request UNICACHE_GET_TIMESTAMP_FLOW through generic(), which writes the
 * int-sized reply payload into *val.  generic()'s status is discarded.
 */
void get_timestamp_flow(int s, int *val)
{
	(void)generic(s, UNICACHE_GRP_IPV4, UNICACHE_GET_TIMESTAMP_FLOW,
		      val, sizeof *val);
}

/*
 * Parse the numeric argument of a --set_* option.
 *
 * Accepts the same forms as strtol() with base 0: decimal, 0x-hex,
 * 0-octal, with an optional sign.  Unlike a bare strtol() call, it rejects
 * empty or partly numeric text instead of silently sending 0 or a
 * truncated value to the kernel.  Prints usage and exits on a missing or
 * invalid argument.
 */
static int parse_int_arg(char **av, const char *arg)
{
	char *end;
	long val;

	if (!arg)
		usage(av);	/* does not return */

	errno = 0;
	val = strtol(arg, &end, 0);
	if (end == arg || *end != '\0' || errno == ERANGE) {
		fprintf(stderr, "%s: invalid number '%s'\n", av[0], arg);
		usage(av);
	}
	/* Same long -> int narrowing as before, so masks like 0xffffffff still work. */
	return (int)val;
}

/*
 * Parse @nkeys whitespace-separated integers from @arg into @key.
 *
 * Each value uses the strtol(base 0) syntax that scanf's "%i" accepts and
 * is stored as its 32-bit two's-complement pattern (-1 -> 0xffffffff).
 * Returns 0 on success, or -1 if fewer than @nkeys valid numbers are
 * present or a value does not fit in 32 bits.
 */
static int parse_flow_key(const char *arg, __u32 *key, int nkeys)
{
	const char *p = arg;
	int k;

	for (k = 0; k < nkeys; k++) {
		char *end;
		long long v;

		errno = 0;
		v = strtoll(p, &end, 0);
		if (end == p || errno == ERANGE ||
		    v < -2147483647LL - 1 || v > 0xffffffffLL)
			return -1;
		key[k] = (__u32)v;
		p = end;
	}
	return 0;
}

/*
 * Open a NETLINK_UNICACHE socket, bind it, then run each long option in
 * command-line order.  Exits 0 after all options have been processed.
 */
int main (int ac, char **av) {

	struct sockaddr_nl src_addr;
	__u32 key[LPK];
	int i = 0, c, s;

	if (ac < 2)
		usage(av);

	if ((s = socket(PF_NETLINK, SOCK_RAW, NETLINK_UNICACHE)) < 0) {
		perror("nl_trie_core: socket");
		exit(-1);
	}

	/*
	 * Bind with our pid and subscribe to the UNICACHE_GRP_IPV4 group mask,
	 * so that messages the kernel sends to this pid or to that group reach
	 * this socket.
	 */
	memset(&src_addr, 0, sizeof(src_addr));
	src_addr.nl_family = AF_NETLINK;
	src_addr.nl_pid = getpid();
	src_addr.nl_groups = UNICACHE_GRP_IPV4;

	if (bind(s, (struct sockaddr *)&src_addr, sizeof(src_addr)) < 0) {
		perror("nl_trie_core: bind");
		close(s);
		exit(-1);
	}

	while ((c = getopt_long(ac, av, "", long_options, NULL)) != -1) {

		switch (c) {

		case UNICACHE_MAX+1:
			usage(av);
			break;

		case UNICACHE_FLUSH:
			do_flush(s);
			break;

		case UNICACHE_SET_GC_THRESH:
			set_gc_thresh(s, parse_int_arg(av, optarg));
			break;

		/*
		 * Each get resets i first, so that a failed query cannot print
		 * a value left over from an earlier option.
		 */
		case UNICACHE_GET_GC_THRESH:
			i = 0;
			get_gc_thresh(s, &i);
			printf("gc_thresh=%d\n", i);
			break;

		case UNICACHE_SET_GC_GOAL:
			set_gc_goal(s, parse_int_arg(av, optarg));
			break;

		case UNICACHE_GET_GC_GOAL:
			i = 0;
			get_gc_goal(s, &i);
			printf("gc_goal=%d\n", i);
			break;

		case UNICACHE_SET_GC_LEVEL:
			set_gc_level(s, parse_int_arg(av, optarg));
			break;

		case UNICACHE_GET_GC_LEVEL:
			i = 0;
			get_gc_level(s, &i);
			printf("gc_level=%d\n", i);
			break;

		case UNICACHE_LOG_FLOW_END:
			dump_flow(s, &src_addr);
			break;

		case UNICACHE_SET_LOG_SAMPLE:
			set_log_sample(s, parse_int_arg(av, optarg));
			break;

		case UNICACHE_SET_TIMESTAMP_FLOW:
			set_timestamp_flow(s, parse_int_arg(av, optarg));
			break;

		case UNICACHE_GET_TIMESTAMP_FLOW:
			i = 0;
			get_timestamp_flow(s, &i);
			printf("timestamp_flow_ratio=%d\n", i);
			break;

		case UNICACHE_INSERT_FLOW:
			/*
			 * Fill exactly LPK words.  A fixed five-field sscanf()
			 * would overrun key[] when LPK < 5 and leave words
			 * uninitialised when LPK > 5.
			 */
			if (!optarg || parse_flow_key(optarg, key, LPK) < 0) {
				fprintf(stderr, "%s: --insert needs %d numbers\n",
					av[0], LPK);
				usage(av);
			}
			insert_flow(s, key);
			break;

		default:
			printf("getopt error: code=%d", c);
			if (optarg)
				printf(" with arg %s", optarg);
			printf("\n");
			break;
		}
	}
	close(s);
	exit(0);
}
