
#include "portmirroring.h"

MODULE_LICENSE("GPL");

//spinlock_t mapping_factor_spinlock;
//mapping_factor_t mapping_factor;

/*
 * TCP flags.
 *
 * Bit values are expressed relative to the low 12 bits of the 16-bit word
 * that uncover_tcp_hdr() reads at byte offset 12 of the TCP header and
 * converts with ntohs() before masking with 0x0FFF.
 */
#define TH_FIN  0x0001
#define TH_SYN  0x0002
#define TH_RST  0x0004
#define TH_PUSH 0x0008
#define TH_ACK  0x0010
#define TH_URG  0x0020
#define TH_ECN  0x0040
#define TH_CWR  0x0080
#define TH_NS   0x0100
#define TH_RES  0x0E00 /* 3 reserved bits */
#define TH_MASK 0x0FFF /* all of the bits above combined */

/*
 * Module parameters (all registered with mode 0644).
 *
 * String parameters count as "not set" when they are NULL or empty; numeric
 * parameters use an all-ones sentinel (or -1 for is_frag) for "not set".
 */

/* Master switch: mirror_handle() accepts every packet untouched while this is 0. */
int do_it=0;
module_param(do_it, int, 0644);

/* Name of the device whose traffic is mirrored (compared with in->name / out->name). */
char *dev_name_from;
module_param(dev_name_from, charp, 0644);
/* Name of the device the cloned packets are transmitted on (looked up in init_net). */
char *dev_name_to;
module_param(dev_name_to, charp, 0644);


/* Source / destination MAC filters; must be 17 characters like "1A:2B:3C:4D:5E:6F". */
char *mac_src;
module_param(mac_src, charp, 0644);
char *mac_dst;
module_param(mac_dst, charp, 0644);
/* Ethertype filter by name: "ip", "arp", "rarp", "vlan" or "pppoe"; other names are ignored. */
char *ether_proto;
module_param(ether_proto, charp, 0644);
/* IP protocol filter by name: "tcp", "udp" or "icmp"; other names are ignored. */
char *ip_proto;
module_param(ip_proto, charp, 0644);
/* Source / destination IPv4 address filters, parsed with in_aton(). */
char *ip_src;
module_param(ip_src, charp, 0644);
char *ip_dst;
module_param(ip_dst, charp, 0644);
/* Fragment filter: -1 = not set, 1 = only packets with MF or a fragment offset, 0 = only packets without. */
int is_frag = -1;
module_param(is_frag, int, 0644);
/* TCP/UDP source / destination port filters; 0xFFFF = not set. */
ushort port_src = 0xFFFF;
module_param(port_src, ushort, 0644);
ushort port_dst = 0xFFFF;
module_param(port_dst, ushort, 0644);
/* TCP sequence / acknowledgement number filters; 0xFFFFFFFF = not set. */
uint tcp_seq = 0xFFFFFFFF;
module_param(tcp_seq, uint, 0644);
uint tcp_ack = 0xFFFFFFFF;
module_param(tcp_ack, uint, 0644);
/* TCP flag filter: names separated by '/', e.g. "SYN/ACK"; FIN, SYN, RST, PSH and ACK are recognised. */
char *tcp_flags;
module_param(tcp_flags, charp, 0644);
/* TCP window filter; 0xFFFF = not set. */
ushort tcp_window = 0xFFFF;
module_param(tcp_window, ushort, 0644);
/* TCP/UDP checksum filter (compared in host byte order); 0xFFFF = not set. */
ushort checksum = 0xFFFF;
module_param(checksum, ushort, 0644);

/* Number of entries in uncover_array. */
#define UNCOVER_COUNT 4

/*
 * One filter stage used by uncover().
 *
 * check_func   - reports (1/0) whether any module parameter relevant to this
 *                stage is set.
 * need_uncover - cached result of check_func, refreshed by check_uncover().
 * process_func - tests a packet against this stage's parameters; returns 1
 *                when the packet matches and 0 when it does not.
 */
struct _uncover
{
    int (*check_func)(void);
    int need_uncover;
    int (*process_func)(struct sk_buff *skb);
};

/*
 * Forward declarations for the filter stages referenced by uncover_array.
 * The check_* functions take no arguments; they are declared with (void) so
 * that they are real prototypes matching the check_func pointer type rather
 * than old-style declarations with unspecified parameters.
 */
int check_uncover_ether_hdr(void);
int check_uncover_ip_hdr(void);
int check_uncover_tcp_hdr(void);
int check_uncover_udp_hdr(void);
int uncover_tcp_hdr(struct sk_buff *skb);
int uncover_udp_hdr(struct sk_buff *skb);
int uncover_ip_hdr(struct sk_buff *skb);
int uncover_ether_hdr(struct sk_buff *skb);

/*
 * Filter stages in the order uncover() evaluates them: Ethernet, IP, TCP, UDP.
 * The middle field (need_uncover) starts at 0 and is recomputed by
 * check_uncover().
 */
struct _uncover uncover_array[UNCOVER_COUNT] =
{
    {check_uncover_ether_hdr, 0, uncover_ether_hdr},
    {check_uncover_ip_hdr, 0, uncover_ip_hdr},
    {check_uncover_tcp_hdr, 0, uncover_tcp_hdr},
    {check_uncover_udp_hdr, 0, uncover_udp_hdr}
};

/*
 * Tell whether any Ethernet-level filter is configured.
 *
 * Returns 1 when mac_src, mac_dst or ether_proto holds a non-empty string,
 * 0 otherwise.
 */
int check_uncover_ether_hdr()
{
    int src_set = (mac_src != NULL) && (mac_src[0] != '\0');
    int dst_set = (mac_dst != NULL) && (mac_dst[0] != '\0');
    int proto_set = (ether_proto != NULL) && (ether_proto[0] != '\0');

    return (src_set || dst_set || proto_set) ? 1 : 0;
}
/*
 * Tell whether any IP-level filter is configured.
 *
 * Returns 1 when ip_proto, ip_src or ip_dst holds a non-empty string, or when
 * is_frag differs from its "not set" value of -1; returns 0 otherwise.
 */
int check_uncover_ip_hdr()
{
    int proto_set = (ip_proto != NULL) && (ip_proto[0] != '\0');
    int src_set = (ip_src != NULL) && (ip_src[0] != '\0');
    int dst_set = (ip_dst != NULL) && (ip_dst[0] != '\0');
    int frag_set = (is_frag != -1);

    return (proto_set || src_set || dst_set || frag_set) ? 1 : 0;
}
/*
 * Tell whether any TCP-level filter is configured.
 *
 * The port and checksum parameters are shared with UDP, so they only count
 * here when ip_proto is exactly "tcp". The TCP-only parameters (tcp_seq,
 * tcp_ack, tcp_flags, tcp_window) count regardless of ip_proto.
 *
 * Returns 1 when such a filter is set, 0 otherwise.
 */
int check_uncover_tcp_hdr()
{
    int shared_l4_set = (port_src != 0xFFFF) ||
                        (port_dst != 0xFFFF) ||
                        (checksum != 0xFFFF);

    if (ip_proto != NULL && strcmp("tcp", ip_proto) == 0 && shared_l4_set)
        return 1;

    if (tcp_seq != 0xFFFFFFFF || tcp_ack != 0xFFFFFFFF)
        return 1;

    if (tcp_flags != NULL && tcp_flags[0] != '\0')
        return 1;

    return (tcp_window != 0xFFFF) ? 1 : 0;
}
/*
 * Tell whether any UDP-level filter is configured.
 *
 * Only applies when ip_proto is exactly "udp"; in that case returns 1 if
 * port_src, port_dst or checksum differs from the 0xFFFF "not set" value.
 * Returns 0 in every other case.
 */
int check_uncover_udp_hdr()
{
    if (ip_proto == NULL || strcmp("udp", ip_proto) != 0)
        return 0;

    return (port_src != 0xFFFF ||
            port_dst != 0xFFFF ||
            checksum != 0xFFFF) ? 1 : 0;
}

void check_uncover()
{
    int i;
    for(i = 0; i < UNCOVER_COUNT; i++)
    {
        uncover_array[i].need_uncover = uncover_array[i].check_func();
    }
}

/*
 * OR the bit for one named TCP flag into *flags.
 *
 * single_flag_str - flag name; "FIN", "SYN", "RST", "PSH" and "ACK" are
 *                   recognised (exact, case-sensitive match).
 * flags           - accumulator that receives the matching TH_* bit.
 *
 * Any other name is unknown or unsupported and leaves *flags unchanged.
 */
void get_tcp_flags(char *single_flag_str, u_int16_t *flags)
{
    static const struct
    {
        const char *name;
        u_int16_t bit;
    } known_flags[] =
    {
        { "FIN", TH_FIN },
        { "SYN", TH_SYN },
        { "RST", TH_RST },
        { "PSH", TH_PUSH },
        { "ACK", TH_ACK },
    };
    unsigned int idx;

    for (idx = 0; idx < sizeof(known_flags) / sizeof(known_flags[0]); idx++)
    {
        if (strcmp(single_flag_str, known_flags[idx].name) == 0)
        {
            *flags |= known_flags[idx].bit;
            return;
        }
    }
}

/*
 * Decide whether TCP matching is impossible for this packet.
 *
 * skb - packet under inspection; skb->data is read as an IPv4 header
 *       (NOTE(review): assumes the hook hands us skb->data at the network
 *       header and that the header is in the linear area - confirm).
 *
 * Returns 1 (conflict, do not attempt TCP matching) when:
 *   - ether_proto is set to something other than "ip",
 *   - the packet's ethertype is not IPv4,
 *   - ip_proto is set to something other than "tcp", or
 *   - the IP header's protocol field is not TCP.
 * Returns 0 otherwise.
 */
int check_tcp_conflict(struct sk_buff *skb)
{
    struct iphdr *iph;

    if(ether_proto != NULL && strcmp("", ether_proto) != 0 && strcmp(ether_proto, "ip") != 0)
    {
        return 1;
    }

    if(skb->protocol != htons(ETH_P_IP))
    {
        return 1;
    }

    /*
     * Compare ip_proto itself. The previous code compared ether_proto with
     * "tcp" here, which dereferenced a NULL ether_proto and reported a
     * conflict for the valid combination ether_proto=ip / ip_proto=tcp.
     */
    if(ip_proto != NULL && strcmp("", ip_proto) != 0 && strcmp(ip_proto, "tcp") != 0)
    {
        return 1;
    }

    iph = (struct iphdr*)skb->data;

    if(iph->protocol != IPPROTO_TCP)
    {
         return 1;
    }

    return 0;
}

/*
 * Match a packet against the configured TCP filters.
 *
 * skb - packet under inspection; skb->data is read as an IPv4 header and the
 *       TCP header is located ihl * 4 bytes after it.
 *
 * Every filter that is set (port_src, port_dst, tcp_seq, tcp_ack, tcp_flags,
 * tcp_window, checksum) must match. The tcp_flags string is a '/'-separated
 * list such as "SYN/ACK"; the named bits must equal the packet's flag bits
 * exactly (low 12 bits of the 16-bit word at offset 12 of the TCP header).
 *
 * Returns 1 when the packet matches, 0 when it does not or when
 * check_tcp_conflict() reports that TCP matching does not apply.
 */
int uncover_tcp_hdr(struct sk_buff *skb)
{
    struct iphdr *iph;
    struct tcphdr *tcph;
    char single_flag_str[8] = {0};
    int i, len, j;
    u_int16_t flags = 0;
    u_int16_t th_flags = 0;

    if(check_tcp_conflict(skb))
        return 0;

    iph = (struct iphdr*)skb->data;
    tcph = (struct tcphdr *)(skb->data + iph->ihl * 4);

    if(port_src != 0xFFFF)
    {
        if(port_src != ntohs(tcph->source))
        {
            return 0;
        }
    }

    if(port_dst != 0xFFFF)
    {
        if(port_dst != ntohs(tcph->dest))
        {
            return 0;
        }
    }

    if(tcp_seq != 0xFFFFFFFF)
    {
        if(tcp_seq != ntohl(tcph->seq))
        {
            return 0;
        }
    }

    if(tcp_ack != 0xFFFFFFFF)
    {
        if(tcp_ack != ntohl(tcph->ack_seq))
        {
            return 0;
        }
    }

    // tcp_flags is something like "SYN/FIN/RST", flags are separated by "/", which means "and".
    if(tcp_flags != NULL && strcmp("", tcp_flags) != 0)
    {
        i = j = 0;
        len = strlen(tcp_flags);
        while(i < len)
        {
            if(*(tcp_flags + i) == '/')
            {
                get_tcp_flags(single_flag_str, &flags);
                memset(single_flag_str, 0, sizeof(single_flag_str));
                j = 0;
            }
            else if(j < (int)sizeof(single_flag_str) - 1)
            {
                /*
                 * Keep room for the terminating NUL. tcp_flags is a
                 * user-supplied module parameter, so an over-long token must
                 * not run past the 8-byte buffer; the truncated token simply
                 * matches no known flag name.
                 */
                single_flag_str[j++] = *(tcp_flags + i);
            }
            i++;
        }
        /* the last token has no trailing '/' */
        get_tcp_flags(single_flag_str, &flags);

        /* bytes 12-13 of the TCP header hold data offset + flag bits */
        memcpy(&th_flags, (char *)tcph + 12, 2);
        th_flags = ntohs(th_flags) & 0x0FFF;

        if(flags != th_flags)
        {
            return 0;
        }

    }


    if(tcp_window != 0xFFFF)
    {
        if(tcp_window != ntohs(tcph->window))
        {
            return 0;
        }
    }

    if(checksum != 0xFFFF)
    {
        if(checksum != ntohs(tcph->check))
        {
            return 0;
        }
    }

    return 1;
}

/*
 * Decide whether UDP matching is impossible for this packet.
 *
 * skb - packet under inspection; skb->data is read as an IPv4 header
 *       (NOTE(review): assumes the hook hands us skb->data at the network
 *       header and that the header is in the linear area - confirm).
 *
 * Returns 1 (conflict, do not attempt UDP matching) when:
 *   - ether_proto is set to something other than "ip",
 *   - the packet's ethertype is not IPv4,
 *   - ip_proto is set to something other than "udp", or
 *   - the IP header's protocol field is not UDP.
 * Returns 0 otherwise.
 */
int check_udp_conflict(struct sk_buff *skb)
{
    struct iphdr *iph;

    if(ether_proto != NULL && strcmp("", ether_proto) != 0 && strcmp(ether_proto, "ip") != 0)
    {
        return 1;
    }

    if(skb->protocol != htons(ETH_P_IP))
    {
        return 1;
    }

    /*
     * Compare ip_proto itself. The previous code compared ether_proto with
     * "udp" here, which dereferenced a NULL ether_proto and reported a
     * conflict for the valid combination ether_proto=ip / ip_proto=udp.
     */
    if(ip_proto != NULL && strcmp("", ip_proto) != 0 && strcmp(ip_proto, "udp") != 0)
    {
        return 1;
    }

    iph = (struct iphdr*)skb->data;

    if(iph->protocol != IPPROTO_UDP)
    {
         return 1;
    }

    return 0;
}


/*
 * Match a packet against the configured UDP filters.
 *
 * skb - packet under inspection; skb->data is read as an IPv4 header and the
 *       UDP header is located ihl * 4 bytes after it.
 *
 * Each of port_src, port_dst and checksum that is set (not 0xFFFF) must
 * equal the corresponding header field in host byte order.
 *
 * Returns 1 on a match, 0 on a mismatch or when check_udp_conflict()
 * reports that UDP matching does not apply.
 */
int uncover_udp_hdr(struct sk_buff *skb)
{
    struct iphdr *iph;
    struct udphdr *udph;

    if (check_udp_conflict(skb))
        return 0;

    iph = (struct iphdr *)skb->data;
    udph = (struct udphdr *)(skb->data + iph->ihl * 4);

    if (port_src != 0xFFFF && port_src != ntohs(udph->source))
        return 0;

    if (port_dst != 0xFFFF && port_dst != ntohs(udph->dest))
        return 0;

    if (checksum != 0xFFFF && checksum != ntohs(udph->check))
        return 0;

    return 1;
}

/*
 * Decide whether IP matching is impossible for this packet.
 *
 * Returns 1 when ether_proto is set to something other than "ip", or when
 * the packet's ethertype is not IPv4; returns 0 otherwise.
 */
int check_ip_conflict(struct sk_buff *skb)
{
    int ether_filter_set = (ether_proto != NULL) && (ether_proto[0] != '\0');

    if (ether_filter_set && strcmp(ether_proto, "ip") != 0)
        return 1;

    return (skb->protocol != htons(ETH_P_IP)) ? 1 : 0;
}

/*
 * Match a packet against the configured IP filters.
 *
 * skb - packet under inspection; skb->data is read as an IPv4 header.
 *
 * Filters applied, each only when set:
 *   ip_proto - "tcp", "udp" or "icmp" must equal the header's protocol;
 *              any other name is ignored.
 *   ip_src   - in_aton(ip_src) must equal the header's source address.
 *   ip_dst   - in_aton(ip_dst) must equal the header's destination address.
 *   is_frag  - 1 requires MF or a non-zero fragment offset, 0 requires
 *              neither; other values impose no restriction.
 *
 * Returns 1 on a match, 0 on a mismatch or when check_ip_conflict() reports
 * that IP matching does not apply.
 */
int uncover_ip_hdr(struct sk_buff *skb)
{
    struct iphdr *iph;
    uint16_t frag_bits;
    int wanted_proto = -1;

    if (check_ip_conflict(skb))
        return 0;

    iph = (struct iphdr *)skb->data;

    if (ip_proto != NULL && ip_proto[0] != '\0')
    {
        if (strcmp(ip_proto, "tcp") == 0)
            wanted_proto = IPPROTO_TCP;
        else if (strcmp(ip_proto, "udp") == 0)
            wanted_proto = IPPROTO_UDP;
        else if (strcmp(ip_proto, "icmp") == 0)
            wanted_proto = IPPROTO_ICMP;
        /* any other name is unknown or not supported: no restriction */

        if (wanted_proto != -1 && iph->protocol != wanted_proto)
            return 0;
    }

    if (ip_src != NULL && ip_src[0] != '\0' && in_aton(ip_src) != iph->saddr)
        return 0;

    if (ip_dst != NULL && ip_dst[0] != '\0' && in_aton(ip_dst) != iph->daddr)
        return 0;

    /* non-zero when the packet is a fragment (MF set or offset != 0) */
    frag_bits = ntohs(iph->frag_off) & (IP_MF|IP_OFFSET);

    if (is_frag == 1 && !frag_bits)
        return 0;

    if (is_frag == 0 && frag_bits)
        return 0;

    return 1;
}

/*
 * Parse up to len leading hexadecimal digits of str (base 16).
 *
 * str - text to parse; upper- and lower-case digits are accepted.
 * len - maximum number of characters to consume.
 *
 * Parsing stops at the first character that is not a hex digit. Returns the
 * value accumulated so far (0 when no digit was consumed).
 */
u_int32_t wanghan_strtoui_16(char *str, int len)
{
    u_int32_t value = 0;
    int pos;

    for (pos = 0; pos < len; pos++)
    {
        char ch = str[pos];
        u_int32_t digit;

        if (ch >= '0' && ch <= '9')
            digit = ch - '0';
        else if (ch >= 'A' && ch <= 'F')
            digit = ch - 'A' + 10;
        else if (ch >= 'a' && ch <= 'f')
            digit = ch - 'a' + 10;
        else
            break;

        value = value * 16 + digit;
    }

    return value;
}

/*
 * Match a packet against the configured Ethernet filters.
 *
 * skb - packet under inspection; its Ethernet header is obtained with
 *       eth_hdr().
 *
 * mac_src / mac_dst must be written like "1A:2B:3C:4D:5E:6F" (17
 * characters); a set value of any other length makes every packet fail.
 * The parsed address must also pass is_valid_ether_addr(), otherwise a
 * message is logged and the packet fails. ether_proto accepts "ip", "arp",
 * "rarp", "vlan" and "pppoe"; any other name imposes no restriction.
 *
 * Returns 1 when every configured filter matches, 0 otherwise.
 */
int uncover_ether_hdr(struct sk_buff *skb)
{
    struct ethhdr *eth = eth_hdr(skb);
    u_int8_t wanted_src[ETH_ALEN] = {0};
    u_int8_t wanted_dst[ETH_ALEN] = {0};
    int have_src = (mac_src != NULL) && (mac_src[0] != '\0');
    int have_dst = (mac_dst != NULL) && (mac_dst[0] != '\0');
    int octet;

    /* both lengths are validated before either address is parsed */
    if (have_src && strlen(mac_src) != 17)
        return 0;
    if (have_dst && strlen(mac_dst) != 17)
        return 0;

    if (have_src)
    {
        /* each octet is two hex digits followed by one separator character */
        for (octet = 0; octet < ETH_ALEN; octet++)
            wanted_src[octet] = (u_int8_t)wanghan_strtoui_16(mac_src + octet * 3, 2);

        if (!is_valid_ether_addr(wanted_src))
        {
            printk("wanghan uncover_ether_hdr invalid ether addr src\n");
            return 0;
        }
        if (memcmp(eth->h_source, wanted_src, ETH_ALEN) != 0)
            return 0;
    }

    if (have_dst)
    {
        for (octet = 0; octet < ETH_ALEN; octet++)
            wanted_dst[octet] = (u_int8_t)wanghan_strtoui_16(mac_dst + octet * 3, 2);

        if (!is_valid_ether_addr(wanted_dst))
        {
            printk("wanghan uncover_ether_hdr invalid ether addr dst\n");
            return 0;
        }
        if (memcmp(eth->h_dest, wanted_dst, ETH_ALEN) != 0)
            return 0;
    }

    if (ether_proto == NULL || ether_proto[0] == '\0')
        return 1;

    if (strcmp(ether_proto, "ip") == 0)
        return (skb->protocol == htons(ETH_P_IP)) ? 1 : 0;

    if (strcmp(ether_proto, "arp") == 0)
        return (skb->protocol == htons(ETH_P_ARP)) ? 1 : 0;

    if (strcmp(ether_proto, "rarp") == 0)
        return (skb->protocol == htons(ETH_P_RARP)) ? 1 : 0;

    if (strcmp(ether_proto, "vlan") == 0)
        return (skb->protocol == htons(ETH_P_8021Q)) ? 1 : 0;

    if (strcmp(ether_proto, "pppoe") == 0)
        return (skb->protocol == htons(ETH_P_PPP_DISC) ||
                skb->protocol == htons(ETH_P_PPP_SES)) ? 1 : 0;

    /* unknown or not supported name: no restriction */
    return 1;
}

/*
 * Run a packet through every filter stage that has parameters configured.
 *
 * The need_uncover flags are refreshed with check_uncover() on each call,
 * then the stages of uncover_array are evaluated in order; evaluation stops
 * at the first stage that rejects the packet.
 *
 * Returns 1 when all active stages accept the packet (or none is active),
 * 0 otherwise.
 */
int uncover(struct sk_buff *skb)
{
    int matched = 1;
    int stage;

    check_uncover();

    for (stage = 0; stage < UNCOVER_COUNT; stage++)
    {
        if (uncover_array[stage].need_uncover != 1)
            continue;

        matched &= uncover_array[stage].process_func(skb);
        if (matched == 0)
            break;
    }

    return matched;
}

unsigned int mirror_handle(unsigned int hooknum,
            struct sk_buff *skb,
            const struct net_device *in,
            const struct net_device *out, int (*okfn) (struct sk_buff *))
{
    struct sk_buff *skb2;

    struct net_device *dev_from;
    struct net_device *dev_to;

    if(do_it==0)
        return NF_ACCEPT;

    if(dev_name_from==NULL)
    {
        printk("portmirroring, dev_name_from is null\n");
        return NF_ACCEPT;
    }

    if(dev_name_to==NULL)
    {
        printk("portmirroring, dev_name_to is null\n");
        return NF_ACCEPT;
    }

    //printk("portmirroring, dev_name_from:%s, dev_name_to:%s\n",dev_name_from,dev_name_to);

    //dev_from=dev_get_by_name(&init_net,dev_name_from);
    dev_to=dev_get_by_name(&init_net,dev_name_to);

    if(dev_to==NULL)
    {
        printk("portmirroring, dev_to is null\n");
        return NF_ACCEPT;
    }

    if(((in!=NULL && strcmp(in->name,dev_name_from)==0)||
        (out!=NULL && strcmp(out->name,dev_name_from)==0)) && uncover(skb) == 1)
    {
        skb2 = skb_clone(skb, GFP_ATOMIC);

        skb2->dev=dev_to;
        skb_push(skb2,ETH_HLEN);//
        skb2->dev->netdev_ops->ndo_start_xmit(skb2,skb2->dev);
        //dev_queue_xmit(skb2);
    }
    
    //dev_put(dev_from);
    dev_put(dev_to);

    return NF_ACCEPT;
}

/*
 * Hook registrations. All four entries route to mirror_handle(); they are
 * registered in tsp_portmirroring_init() and unregistered in
 * tsp_portmirroring_cleanup().
 */

/* Bridge family, pre-routing hook. */
static struct nf_hook_ops mirror_br_hook_pre = {
    .hook = mirror_handle,
    .owner = THIS_MODULE,
    .pf = PF_BRIDGE,
    .hooknum = NF_BR_PRE_ROUTING,
    .priority = NF_BR_PRI_FIRST,
};

/* Bridge family, post-routing hook. */
static struct nf_hook_ops mirror_br_hook_post = {
    .hook = mirror_handle,
    .owner = THIS_MODULE,
    .pf = PF_BRIDGE,
    .hooknum = NF_BR_POST_ROUTING,
    .priority = NF_BR_PRI_FIRST,
};

/* IPv4 family, pre-routing hook. */
static struct nf_hook_ops mirror_inet_hook_pre = {
    .hook = mirror_handle,
    .owner = THIS_MODULE,
    .pf = PF_INET,
    .hooknum = NF_INET_PRE_ROUTING,
    .priority = NF_IP_PRI_FIRST,
};

/* IPv4 family, post-routing hook. */
static struct nf_hook_ops mirror_inet_hook_post = {
    .hook = mirror_handle,
    .owner = THIS_MODULE,
    .pf = PF_INET,
    .hooknum = NF_INET_POST_ROUTING,
    .priority = NF_IP_PRI_FIRST,
};

/*
 * Module entry point: register the four netfilter hooks.
 *
 * Returns 0 when every hook was registered. If any registration fails, the
 * hooks registered so far are unregistered again and the error code from
 * nf_register_hook() is returned, so the module is not left loaded with a
 * partial set of hooks (which the exit path would then try to unregister).
 */
static int __init
tsp_portmirroring_init(void)
{
    int ret;

    printk("start init portmirroring\n");

    ret = nf_register_hook(&mirror_br_hook_pre);
    if (ret != 0)
        goto err;

    ret = nf_register_hook(&mirror_br_hook_post);
    if (ret != 0)
        goto err_unreg_br_pre;

    ret = nf_register_hook(&mirror_inet_hook_pre);
    if (ret != 0)
        goto err_unreg_br_post;

    ret = nf_register_hook(&mirror_inet_hook_post);
    if (ret != 0)
        goto err_unreg_inet_pre;

    printk("init portmirroring done\n");

    return 0;

    /* unwind in reverse order of registration */
err_unreg_inet_pre:
    nf_unregister_hook(&mirror_inet_hook_pre);
err_unreg_br_post:
    nf_unregister_hook(&mirror_br_hook_post);
err_unreg_br_pre:
    nf_unregister_hook(&mirror_br_hook_pre);
err:
    printk("init portmirroring failed\n");
    return ret;
}

/*
 * Module exit point: unregister the four netfilter hooks (bridge pre/post,
 * then inet pre/post) and log completion.
 */
static void __exit
tsp_portmirroring_cleanup(void)
{
    struct nf_hook_ops *hooks[] = {
        &mirror_br_hook_pre,
        &mirror_br_hook_post,
        &mirror_inet_hook_pre,
        &mirror_inet_hook_post,
    };
    unsigned int idx;

    for (idx = 0; idx < sizeof(hooks) / sizeof(hooks[0]); idx++)
        nf_unregister_hook(hooks[idx]);

    printk("portmirroring cleanup done\n");
}

/* Bind the load / unload entry points of the module. */
module_init(tsp_portmirroring_init);
module_exit(tsp_portmirroring_cleanup);
