/* * Copyright (c) 2011 Qualcomm Atheros, Inc. * */

#include <linux/in.h>
#include <linux/ip.h>
#include <linux/udp.h>
#include <linux/tcp.h>
#include <linux/icmp.h>
#include <net/ip.h>
#include <linux/if_arp.h>

#include <linux/inetdevice.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/netfilter_arp.h>
#include <linux/netfilter_ipv4/ip_tables.h>
#include <linux/netfilter/xt_multiport.h>
#include <linux/netfilter/xt_iprange.h>
#include <linux/netfilter/nf_conntrack_tcp.h>
#include <net/checksum.h>
#include <net/dsfield.h>
#include <net/route.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_nat_core.h>
#include <net/netfilter/nf_nat_rule.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <linux/module.h>

#include <linux/proc_fs.h>

#include <net/SI/fast_common.h>
#include <net/inet_hashtables.h>
#include <linux/igmp.h>

#include <net/netfilter/nf_conntrack_l4proto.h>

#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_l3proto.h>
#include <net/netfilter/nf_conntrack_l4proto.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_conntrack_extend.h>
#include <net/netfilter/nf_conntrack_acct.h>
#include <net/netfilter/nf_conntrack_ecache.h>
#include <net/netfilter/nf_conntrack_zones.h>
#include <net/netfilter/nf_conntrack_timestamp.h>
#include <net/netfilter/nf_conntrack_timeout.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_nat_core.h>

MODULE_LICENSE("GPL");

/*
 * Fast-path threshold tunables (both default to 10).
 * NOTE(review): the original comment here was mis-encoded and unreadable.
 * Presumably these are per-socket packet-count thresholds that gate entry
 * into the fast path -- confirm against the users of these variables.
 * Only send_fast_threshold is exported to other modules.
 */
int  rcv_fast_threshold = 10;
int  send_fast_threshold = 10;
EXPORT_SYMBOL(send_fast_threshold);

/*
 * Transport-layer receive entry points defined elsewhere in the kernel.
 * The fast_local*_recv() functions in this file call them directly to
 * deliver a packet to the local TCP/UDP stack.
 */
extern int tcp_v4_rcv(struct sk_buff *skb);
extern int udp_rcv(struct sk_buff *skb);

extern int tcp_v6_rcv(struct sk_buff *skb);
extern int udpv6_rcv(struct sk_buff *skb);

/* **************************** Function implementations ************************ */
/*
 * fast_local4_recv - hand an IPv4 TCP/UDP packet straight to the local
 * transport layer, bypassing the regular receive path.
 *
 * @tmpl:     conntrack template the caller detached from the skb, or NULL
 * @skb:      packet being received; skb->dev must be non-NULL
 * @ct:       conntrack entry matched for this packet; this function owns
 *            one reference to it on entry
 * @l4proto:  layer-4 conntrack tracker for the packet
 * @dataoff:  offset passed through to l4proto->packet()
 * @dir:      1 for the reply direction, anything else for original
 * @protonum: layer-4 protocol number (not used by the IPv4 variant)
 *
 * Returns 1 when the skb has been consumed here (delivered to
 * tcp_v4_rcv()/udp_rcv(), or freed after a failed un-clone), and 0 when
 * the packet was not taken; in the latter case the ct reference has been
 * dropped and skb->nfct has been reset to @tmpl (or NULL).
 */
int fast_local4_recv(struct nf_conn *tmpl, 
    struct sk_buff *skb, 
    struct nf_conn *ct, 
    struct nf_conntrack_l4proto *l4proto,
    unsigned int dataoff,
    int dir,
    u_int8_t protonum)
{
    struct iphdr *iph = NULL;
    struct sock *sk = NULL;
    enum ip_conntrack_info ctinfo;
    unsigned int *timeouts = NULL;
    struct nf_conn_timeout *timeout_ext = NULL;
    u_int8_t l4num;
    int ret;

    if (skb->dev == NULL)
        panic("fast_local4_recv skb->dev == NULL");

    if (skb->pkt_type != PACKET_HOST){
        skb->pkt_type = PACKET_HOST;
    }
    rcu_read_lock();

    /*
     * Take a reference on the socket bound to this conntrack entry.
     * On the success path the matching put is expected to happen inside
     * tcp_v4_rcv()/udp_rcv() (per the original author's note -- TODO
     * confirm); every failure path below must drop it explicitly.
     */
    spin_lock_bh(&fastlocal_spinlock);
    sk = rcu_dereference(ct->fast_ct.sk);
    if (!sk)
    {
        spin_unlock_bh(&fastlocal_spinlock);
        goto out;
    }
    sock_hold(sk);
    spin_unlock_bh(&fastlocal_spinlock);

    /*
     * Cache the protocol number now: pskb_expand_head() below may
     * reallocate the header, which would leave 'iph' dangling.
     */
    iph = ip_hdr(skb);
    l4num = iph->protocol;

    /* Only established TCP connections take the fast path. */
    if (IPPROTO_TCP == l4num && sk->sk_state != TCP_ESTABLISHED)
    {
        sock_put(sk);
        goto out;
    }

    skb->nfct = &ct->ct_general;

    if (dir == 1) {
        ctinfo = IP_CT_ESTABLISHED_REPLY;
    } else {
        if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) {
            ctinfo = IP_CT_ESTABLISHED;
        } else if (test_bit(IPS_EXPECTED_BIT, &ct->status)) {
            ctinfo = IP_CT_RELATED;
        } else {
            ctinfo = IP_CT_NEW;
        }
    }
    skb->nfctinfo = ctinfo;

    /* Prefer per-connection timeouts; fall back to the protocol defaults. */
    timeout_ext = nf_ct_timeout_find(ct);
    if (timeout_ext)
        timeouts = NF_CT_TIMEOUT_EXT_DATA(timeout_ext);
    else
        timeouts = l4proto->get_timeouts(&init_net);

    /* Let the L4 tracker process the packet with the selected timeouts. */
    ret = l4proto->packet(ct, skb, dataoff, ctinfo, PF_INET, NF_INET_PRE_ROUTING, timeouts);
    if (ret <= 0) {
        sock_put(sk);
        skb->nfct = NULL;
        printk("fast_local4_recv l4proto->packet error\n");
        goto out;
    }

    fast_tcpdump(skb);

    /*
     * The skb is modified below, so a cloned skb (presumably cloned by
     * packet capture -- TODO confirm) needs a private copy of its head.
     */
    if (skb_cloned(skb))
    {
        if (pskb_expand_head(skb, 0, 0, GFP_ATOMIC))
        {
            /*
             * Detach ct from the skb first so that freeing the skb does
             * not release our single ct reference a second time; it is
             * released explicitly just below.
             */
            skb->nfct = NULL;
            kfree_skb(skb);
            sock_put(sk);
            nf_conntrack_put(&ct->ct_general);
            printk("fast_local4_recv pskb_expand_head skb failed, free skb\n");
            if (tmpl)
                nf_conntrack_put(&tmpl->ct_general);
            rcu_read_unlock();
            return 1;
        }
    }

    skb->isFastlocal = 1;

    /* Per-connection statistics. */
    ct->packet_info[dir].packets++;
    ct->packet_info[dir].bytes += skb->len;

    /* Net-device receive statistics, only at the FAST_NET_DEVICE level. */
    if (fastnat_level == FAST_NET_DEVICE)
    {
        skb->dev->stats.rx_packets++;
        skb->dev->stats.rx_bytes += skb->len;
    }

    /* Strip the IP header so skb->data points at the transport header. */
    skb_pull(skb, ip_hdrlen(skb));
    skb_reset_transport_header(skb);

    /* The receive routine's return value is not used by this function. */
    if (IPPROTO_TCP == l4num)
    {
        tcp_v4_rcv(skb);
    }
    else
    {
        udp_rcv(skb);
    }
    if (tmpl)
        nf_conntrack_put(&tmpl->ct_general);
    rcu_read_unlock();
    return 1;
out:
    /* Not taken: drop our ct reference and give the skb back its template. */
    nf_conntrack_put(&ct->ct_general);
    if(tmpl)
    {
        skb->nfct = &tmpl->ct_general;
    }
    else{
        skb->nfct = NULL;
    }
    rcu_read_unlock();
    return 0;
}

int fast_local_output(struct sk_buff *skb)
{    
    struct iphdr *iph = ip_hdr(skb);
    struct nf_conn *ct = NULL, *tmpl = NULL;
    struct sock *sk = skb->sk;
    enum ip_conntrack_info ctinfo;
    struct nf_conntrack_l4proto *l4proto;
    struct nf_conn_timeout *timeout_ext = NULL;
    unsigned int *timeouts;
    unsigned int dataoff;
    u_int8_t protonum;
    int dir = 0;
    int ret;

    //ڿλͼش򿪣ܿعرգκûܿؿƣҪж
    if (fastnat_level == FAST_CLOSE || fastnat_level == FAST_CLOSE_KEEP_LINK)
        return 0;

    //ֻTCPUDPпתpingʱҲƥ䵽TCPUDPctҪʾж
    if (IPPROTO_TCP != iph->protocol && IPPROTO_UDP != iph->protocol)
        return 0;

#if 0
    if(sk->sk_send_sum < send_fast_threshold)
        return 0;
#endif

    //Ƭпת
    if (ip_is_fragment(iph))
        return 0;

    iph->tot_len = htons(skb->len);
    ip_send_check(iph);
    
    ct = skb_get_ct(&tmpl, skb, &l4proto, &dataoff, PF_INET, NF_INET_LOCAL_OUT, &dir, &protonum);
    if (!ct)
        return 0;
    
    rcu_read_lock();
    
    if (dir == 1) {
        ctinfo = IP_CT_ESTABLISHED_REPLY;
    } else {
        if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) {
            ctinfo = IP_CT_ESTABLISHED;
        } else if (test_bit(IPS_EXPECTED_BIT, &ct->status)) {
            ctinfo = IP_CT_RELATED;
        } else {
            ctinfo = IP_CT_NEW;
        }
    }
    
    timeout_ext = nf_ct_timeout_find(ct);
    if (timeout_ext)
        timeouts = NF_CT_TIMEOUT_EXT_DATA(timeout_ext);
    else
        timeouts = l4proto->get_timeouts(&init_net);


    //¶ʱ״̬
    ret = l4proto->packet(ct, skb, dataoff, ctinfo, PF_INET, NF_INET_LOCAL_OUT, timeouts);
    if (ret <= 0) {
        skb->nfct = NULL;
        printk("fast_local_output l4proto->packet error\n");
        goto out;
    }
    
    skb->dev = skb_dst(skb)->dev;
    skb->protocol = htons(ETH_P_IP);
    skb->nfct = &ct->ct_general;
    skb->nfctinfo = ctinfo;
    NF_CT_ASSERT(skb->nfct);
    
    if (tmpl)
        nf_conntrack_put(&tmpl->ct_general);
    rcu_read_unlock();

    //ctӵͳ
    ct->packet_info[dir].packets++;
    ct->packet_info[dir].bytes += skb->len;
    
    return 1;
out:
    nf_conntrack_put(&ct->ct_general);
    rcu_read_unlock();
    if (tmpl)
    {
        skb->nfct = &tmpl->ct_general;
    }
    else{
        skb->nfct = NULL;
    }
    return 0;
}

/*
 * fast_local6_recv - hand an IPv6 TCP/UDP packet straight to the local
 * transport layer, bypassing the regular receive path.
 *
 * @tmpl:     conntrack template the caller detached from the skb, or NULL
 * @skb:      packet being received; skb->dev must be non-NULL
 * @ct:       conntrack entry matched for this packet; this function owns
 *            one reference to it on entry
 * @l4proto:  layer-4 conntrack tracker for the packet
 * @dataoff:  offset of the transport header; passed to l4proto->packet()
 *            and used as the amount pulled from the skb before delivery
 * @dir:      1 for the reply direction, anything else for original
 * @protonum: layer-4 protocol; NEXTHDR_TCP selects tcp_v6_rcv(), anything
 *            else goes to udpv6_rcv()
 *
 * Returns 1 when the skb has been consumed here (delivered, or freed after
 * a failed un-clone), and 0 when the packet was not taken; in the latter
 * case the ct reference has been dropped and skb->nfct has been reset to
 * @tmpl (or NULL).
 */
int fast_local6_recv(struct nf_conn * tmpl, 
    struct sk_buff *skb, 
    struct nf_conn *ct, 
    struct nf_conntrack_l4proto *l4proto,
    unsigned int dataoff,
    int dir,
    u_int8_t protonum)
{
    enum ip_conntrack_info ctinfo;
    struct sock *sk = NULL;
    struct nf_conn_timeout *timeout_ext = NULL;
    unsigned int *timeouts = NULL;
    int ret;

    rcu_read_lock();

    if (skb->dev == NULL)
        panic("fast_local6_recv skb->dev == NULL");

    if (skb->pkt_type != PACKET_HOST){
        skb->pkt_type = PACKET_HOST;
    }

    /*
     * Take a reference on the socket bound to this conntrack entry.
     * On the success path the matching put is expected to happen inside
     * tcp_v6_rcv()/udpv6_rcv() (per the original author's note -- TODO
     * confirm); every failure path below must drop it explicitly.
     */
    spin_lock_bh(&fastlocal_spinlock);
    sk = rcu_dereference(ct->fast_ct.sk);
    if (!sk)
    {
        spin_unlock_bh(&fastlocal_spinlock);
        goto out;
    }
    sock_hold(sk);
    spin_unlock_bh(&fastlocal_spinlock);

    /* Only established TCP connections take the fast path. */
    if (NEXTHDR_TCP == protonum && sk->sk_state != TCP_ESTABLISHED)
    {
        sock_put(sk);
        goto out;
    }

    if (dir == 1) {
        ctinfo = IP_CT_ESTABLISHED_REPLY;
    } else {
        if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) {
            ctinfo = IP_CT_ESTABLISHED;
        } else if (test_bit(IPS_EXPECTED_BIT, &ct->status)) {
            ctinfo = IP_CT_RELATED;
        } else {
            ctinfo = IP_CT_NEW;
        }
    }

    /* Prefer per-connection timeouts; fall back to the protocol defaults. */
    timeout_ext = nf_ct_timeout_find(ct);
    if (timeout_ext)
        timeouts = NF_CT_TIMEOUT_EXT_DATA(timeout_ext);
    else
        timeouts = l4proto->get_timeouts(&init_net);

    /* Let the L4 tracker process the packet with the selected timeouts. */
    ret = l4proto->packet(ct, skb, dataoff, ctinfo, PF_INET6, NF_INET_PRE_ROUTING, timeouts);
    if (ret <= 0) {
        sock_put(sk);
        skb->nfct = NULL;
        printk("fast_local6_recv l4proto->packet error\n");
        goto out;
    }

    fast_tcpdump(skb);

    /*
     * The skb is modified below, so a cloned skb (presumably cloned by
     * packet capture -- TODO confirm) needs a private copy of its head.
     */
    if (skb_cloned(skb))
    {
        if (pskb_expand_head(skb, 0, 0, GFP_ATOMIC))
        {
            /*
             * The transport receive routine will not run, so release
             * the socket reference and the RCU read lock ourselves.
             */
            kfree_skb(skb);
            sock_put(sk);
            nf_conntrack_put(&ct->ct_general);
            printk("fast_local6_recv pskb_expand_head skb failed, free skb\n");
            if (tmpl)
                nf_conntrack_put(&tmpl->ct_general);
            rcu_read_unlock();
            return 1;
        }
    }

    /* From here on the skb carries the conntrack reference. */
    skb->nfct = &ct->ct_general;
    skb->nfctinfo = ctinfo;
    skb->isFastlocal = 1;

    /* Per-connection statistics. */
    ct->packet_info[dir].packets++;
    ct->packet_info[dir].bytes += skb->len;

    /* Net-device receive statistics, only at the FAST_NET_DEVICE level. */
    if (fastnat_level == FAST_NET_DEVICE)
    {
        skb->dev->stats.rx_packets++;
        skb->dev->stats.rx_bytes += skb->len;
    }

    /* Strip everything before the transport header. */
    skb_pull(skb, dataoff);
    skb_reset_transport_header(skb);

    /* The receive routine's return value is not used by this function. */
    if (NEXTHDR_TCP == protonum)
    {
        tcp_v6_rcv(skb);
    }
    else
    {
        udpv6_rcv(skb);
    }
    if (tmpl)
        nf_conntrack_put(&tmpl->ct_general);
    rcu_read_unlock();
    return 1;
out:
    /* Not taken: drop our ct reference and give the skb back its template. */
    nf_conntrack_put(&ct->ct_general);
    if (tmpl){
        skb->nfct = &tmpl->ct_general;
    }
    else{
        skb->nfct = NULL;
    }
    rcu_read_unlock();
    return 0;
}

int fast_local_output_v6(struct sk_buff *skb)
{    
    struct ipv6hdr *iph6 = ipv6_hdr(skb);
    struct nf_conn *ct = NULL, *tmpl = NULL;
    struct sock *sk = skb->sk;
    enum ip_conntrack_info ctinfo;
    struct nf_conntrack_l4proto *l4proto;
    unsigned int *timeouts;
    struct nf_conn_timeout *timeout_ext = NULL;
    unsigned int dataoff;
    u_int8_t protonum;
    int dir = 0;
    int type;
    int ret;
    int len;

    //ڿλͼش򿪣ܿعرգκûܿؿƣҪж
    if (fastnat_level == FAST_CLOSE || fastnat_level == FAST_CLOSE_KEEP_LINK)
        return 0;

    //Ƭпת
    if (skb->nfct_reasm)
        return 0;
    
    len = skb->len - sizeof(struct ipv6hdr);
    if (len > IPV6_MAXPLEN)
        len = 0;
    
    iph6->payload_len = htons(len);
    
    ct = skb_get_ct(&tmpl, skb, &l4proto, &dataoff, PF_INET6, NF_INET_LOCAL_OUT, &dir, &protonum);
    if(!ct) 
        return 0;

    //ֻTCPUDPпתpingʱҲƥ䵽TCPUDPctҪʾж
    if(IPPROTO_TCP != protonum && IPPROTO_UDP != protonum)
        return 0;
    
    rcu_read_lock();
    //spin_lock_bh(&fastlocal_spinlock);       
    
    if (dir == 1) {
        ctinfo = IP_CT_ESTABLISHED_REPLY;
    } else {
        if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) {
            ctinfo = IP_CT_ESTABLISHED;
        } else if (test_bit(IPS_EXPECTED_BIT, &ct->status)) {
            ctinfo = IP_CT_RELATED;
        } else {
            ctinfo = IP_CT_NEW;
        }
    }

    timeout_ext = nf_ct_timeout_find(ct);
    if (timeout_ext)
        timeouts = NF_CT_TIMEOUT_EXT_DATA(timeout_ext);
    else
        timeouts = l4proto->get_timeouts(&init_net);


    //¶ʱ״̬
    ret = l4proto->packet(ct, skb, dataoff, ctinfo, PF_INET6, NF_INET_LOCAL_OUT, timeouts);
    if (ret <= 0) {
        skb->nfct = NULL;
        //spin_unlock_bh(&fastlocal_spinlock); 
        printk("fast_local_output l4proto->packet error\n");
        goto out;
    }
    
    skb->dev = skb_dst(skb)->dev;
    skb->protocol = htons(ETH_P_IPV6);
    skb->nfct = &ct->ct_general;
    skb->nfctinfo = ctinfo;
    NF_CT_ASSERT(skb->nfct);

    if (tmpl)
        nf_conntrack_put(&tmpl->ct_general);
    //spin_unlock_bh(&fastlocal_spinlock);   
    rcu_read_unlock();

    //ctӵͳ
    ct->packet_info[dir].packets++;
    ct->packet_info[dir].bytes += skb->len;
    return 1;
out:
    nf_conntrack_put(&ct->ct_general);
    rcu_read_unlock();
    if(tmpl)
    {
        skb->nfct = &tmpl->ct_general;
    }
    else{
        skb->nfct = NULL;
    }
    return 0;
        
}

//connͷ֪֪ͨͨͷűfast
void fast_local_conn_release(struct nf_conn *ct)
{
    struct sock *sk;
    struct conn_list *entry;

    if (!(sk = ct->fast_ct.sk))
        return;
    
    if (sk->conn_list_init == 0)
        goto out;
    
     list_for_each_entry_rcu(entry, &sk->conn_head, list)
     {
        if (entry->nfct == ct)
        {    
            __list_del_entry(&entry->list);
            kfree(entry);
            break;
        }
    }
    
out:
    rcu_assign_pointer(ct->fast_ct.sk, NULL);
    ct->fast_ct.isFast = 0;
}

/*
 * fast_local_sock_release - detach a socket from every conntrack entry that
 * is linked to it.
 *
 * Walks sk->conn_head (only if the list has been initialised); for each
 * node whose conntrack entry is still bound to @sk, clears that binding
 * and the entry's isFast flag. Every node is unlinked and freed.
 * NOTE(review): presumably invoked when the socket is being released --
 * confirm against the caller.
 */
void fast_local_sock_release(struct sock *sk)
{
    struct conn_list *entry;
    struct conn_list *entry_temp;
    struct nf_conn *ct;

    if (sk->conn_list_init == 0)
        return;

    list_for_each_entry_safe(entry, entry_temp, &sk->conn_head, list){
        /*
         * Plain load into a local: rcu_assign_pointer() is meant for
         * publishing a pointer to readers, not for reading one.
         */
        ct = entry->nfct;

        /* sk is non-NULL here, so equality also rules out an unbound ct. */
        if (ct && ct->fast_ct.sk == sk)
        {
            rcu_assign_pointer(ct->fast_ct.sk, NULL);
            ct->fast_ct.isFast = 0;
        }
        entry->nfct = NULL;
        __list_del_entry(&entry->list);
        kfree(entry);
    }
}


/*
 * Init hook: prepares fastlocal_spinlock, which the receive paths in this
 * file take around their socket lookup. Always returns 0.
 */
static int __init tsp_fastlocal_init(void)
{
    spin_lock_init(&fastlocal_spinlock);

    return 0;
}
/* Exit hook: intentionally empty, this file has nothing to tear down. */
static void __exit tsp_fastlocal_cleanup(void)
{
}

/* Register the init hook as a late initcall and the exit hook for unload. */
late_initcall(tsp_fastlocal_init);
module_exit(tsp_fastlocal_cleanup);


