Pileus Git - ~andy/linux/blobdiff - net/packet/af_packet.c
packet: set transport header before doing xmit
[~andy/linux] / net / packet / af_packet.c
index 1d6793dbfbae23f8c3f28163c0a67718d489d3ed..83fdd0a87eb6d738b6932c9cbd92d9ca7739a964 100644 (file)
@@ -88,6 +88,7 @@
 #include <linux/virtio_net.h>
 #include <linux/errqueue.h>
 #include <linux/net_tstamp.h>
+#include <net/flow_keys.h>
 
 #ifdef CONFIG_INET
 #include <net/inet_common.h>
@@ -181,6 +182,8 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
 
 struct packet_sock;
 static int tpacket_snd(struct packet_sock *po, struct msghdr *msg);
+static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
+                      struct packet_type *pt, struct net_device *orig_dev);
 
 static void *packet_previous_frame(struct packet_sock *po,
                struct packet_ring_buffer *rb,
@@ -973,11 +976,11 @@ static void *packet_current_rx_frame(struct packet_sock *po,
 
 static void *prb_lookup_block(struct packet_sock *po,
                                     struct packet_ring_buffer *rb,
-                                    unsigned int previous,
+                                    unsigned int idx,
                                     int status)
 {
        struct tpacket_kbdq_core *pkc  = GET_PBDQC_FROM_RB(rb);
-       struct tpacket_block_desc *pbd = GET_PBLOCK_DESC(pkc, previous);
+       struct tpacket_block_desc *pbd = GET_PBLOCK_DESC(pkc, idx);
 
        if (status != BLOCK_STATUS(pbd))
                return NULL;
@@ -1041,6 +1044,29 @@ static void packet_increment_head(struct packet_ring_buffer *buff)
        buff->head = buff->head != buff->frame_max ? buff->head+1 : 0;
 }
 
+static bool packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb)
+{
+       struct sock *sk = &po->sk;
+       bool has_room;
+
+       if (po->prot_hook.func != tpacket_rcv)
+               return (atomic_read(&sk->sk_rmem_alloc) + skb->truesize)
+                       <= sk->sk_rcvbuf;
+
+       spin_lock(&sk->sk_receive_queue.lock);
+       if (po->tp_version == TPACKET_V3)
+               has_room = prb_lookup_block(po, &po->rx_ring,
+                                           po->rx_ring.prb_bdqc.kactive_blk_num,
+                                           TP_STATUS_KERNEL);
+       else
+               has_room = packet_lookup_frame(po, &po->rx_ring,
+                                              po->rx_ring.head,
+                                              TP_STATUS_KERNEL);
+       spin_unlock(&sk->sk_receive_queue.lock);
+
+       return has_room;
+}
+
 static void packet_sock_destruct(struct sock *sk)
 {
        skb_queue_purge(&sk->sk_error_queue);
@@ -1066,16 +1092,16 @@ static int fanout_rr_next(struct packet_fanout *f, unsigned int num)
        return x;
 }
 
-static struct sock *fanout_demux_hash(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+static unsigned int fanout_demux_hash(struct packet_fanout *f,
+                                     struct sk_buff *skb,
+                                     unsigned int num)
 {
-       u32 idx, hash = skb->rxhash;
-
-       idx = ((u64)hash * num) >> 32;
-
-       return f->arr[idx];
+       return (((u64)skb->rxhash) * num) >> 32;
 }
 
-static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+static unsigned int fanout_demux_lb(struct packet_fanout *f,
+                                   struct sk_buff *skb,
+                                   unsigned int num)
 {
        int cur, old;
 
@@ -1083,14 +1109,40 @@ static struct sock *fanout_demux_lb(struct packet_fanout *f, struct sk_buff *skb
        while ((old = atomic_cmpxchg(&f->rr_cur, cur,
                                     fanout_rr_next(f, num))) != cur)
                cur = old;
-       return f->arr[cur];
+       return cur;
+}
+
+static unsigned int fanout_demux_cpu(struct packet_fanout *f,
+                                    struct sk_buff *skb,
+                                    unsigned int num)
+{
+       return smp_processor_id() % num;
 }
 
-static struct sock *fanout_demux_cpu(struct packet_fanout *f, struct sk_buff *skb, unsigned int num)
+static unsigned int fanout_demux_rollover(struct packet_fanout *f,
+                                         struct sk_buff *skb,
+                                         unsigned int idx, unsigned int skip,
+                                         unsigned int num)
 {
-       unsigned int cpu = smp_processor_id();
+       unsigned int i, j;
 
-       return f->arr[cpu % num];
+       i = j = min_t(int, f->next[idx], num - 1);
+       do {
+               if (i != skip && packet_rcv_has_room(pkt_sk(f->arr[i]), skb)) {
+                       if (i != j)
+                               f->next[idx] = i;
+                       return i;
+               }
+               if (++i == num)
+                       i = 0;
+       } while (i != j);
+
+       return idx;
+}
+
+static bool fanout_has_flag(struct packet_fanout *f, u16 flag)
+{
+       return f->flags & (flag >> 8);
 }
 
 static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
@@ -1099,7 +1151,7 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
        struct packet_fanout *f = pt->af_packet_priv;
        unsigned int num = f->num_members;
        struct packet_sock *po;
-       struct sock *sk;
+       unsigned int idx;
 
        if (!net_eq(dev_net(dev), read_pnet(&f->net)) ||
            !num) {
@@ -1110,23 +1162,31 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
        switch (f->type) {
        case PACKET_FANOUT_HASH:
        default:
-               if (f->defrag) {
+               if (fanout_has_flag(f, PACKET_FANOUT_FLAG_DEFRAG)) {
                        skb = ip_check_defrag(skb, IP_DEFRAG_AF_PACKET);
                        if (!skb)
                                return 0;
                }
                skb_get_rxhash(skb);
-               sk = fanout_demux_hash(f, skb, num);
+               idx = fanout_demux_hash(f, skb, num);
                break;
        case PACKET_FANOUT_LB:
-               sk = fanout_demux_lb(f, skb, num);
+               idx = fanout_demux_lb(f, skb, num);
                break;
        case PACKET_FANOUT_CPU:
-               sk = fanout_demux_cpu(f, skb, num);
+               idx = fanout_demux_cpu(f, skb, num);
+               break;
+       case PACKET_FANOUT_ROLLOVER:
+               idx = fanout_demux_rollover(f, skb, 0, (unsigned int) -1, num);
                break;
        }
 
-       po = pkt_sk(sk);
+       po = pkt_sk(f->arr[idx]);
+       if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER) &&
+           unlikely(!packet_rcv_has_room(po, skb))) {
+               idx = fanout_demux_rollover(f, skb, idx, idx, num);
+               po = pkt_sk(f->arr[idx]);
+       }
 
        return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev);
 }
@@ -1175,10 +1235,13 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
        struct packet_sock *po = pkt_sk(sk);
        struct packet_fanout *f, *match;
        u8 type = type_flags & 0xff;
-       u8 defrag = (type_flags & PACKET_FANOUT_FLAG_DEFRAG) ? 1 : 0;
+       u8 flags = type_flags >> 8;
        int err;
 
        switch (type) {
+       case PACKET_FANOUT_ROLLOVER:
+               if (type_flags & PACKET_FANOUT_FLAG_ROLLOVER)
+                       return -EINVAL;
        case PACKET_FANOUT_HASH:
        case PACKET_FANOUT_LB:
        case PACKET_FANOUT_CPU:
@@ -1203,7 +1266,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                }
        }
        err = -EINVAL;
-       if (match && match->defrag != defrag)
+       if (match && match->flags != flags)
                goto out;
        if (!match) {
                err = -ENOMEM;
@@ -1213,7 +1276,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
                write_pnet(&match->net, sock_net(sk));
                match->id = id;
                match->type = type;
-               match->defrag = defrag;
+               match->flags = flags;
                atomic_set(&match->rr_cur, 0);
                INIT_LIST_HEAD(&match->list);
                spin_lock_init(&match->lock);
@@ -1350,6 +1413,7 @@ static int packet_sendmsg_spkt(struct kiocb *iocb, struct socket *sock,
        __be16 proto = 0;
        int err;
        int extra_len = 0;
+       struct flow_keys keys;
 
        /*
         *      Get and verify the address.
@@ -1450,6 +1514,11 @@ retry:
        if (unlikely(extra_len == 4))
                skb->no_fcs = 1;
 
+       if (skb_flow_dissect(skb, &keys))
+               skb_set_transport_header(skb, keys.thoff);
+       else
+               skb_reset_transport_header(skb);
+
        dev_queue_xmit(skb);
        rcu_read_unlock();
        return len;
@@ -1856,6 +1925,7 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
        struct page *page;
        void *data;
        int err;
+       struct flow_keys keys;
 
        ph.raw = frame;
 
@@ -1881,6 +1951,11 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
        skb_reserve(skb, hlen);
        skb_reset_network_header(skb);
 
+       if (skb_flow_dissect(skb, &keys))
+               skb_set_transport_header(skb, keys.thoff);
+       else
+               skb_reset_transport_header(skb);
+
        if (po->tp_tx_has_off) {
                int off_min, off_max, off;
                off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll);
@@ -2137,6 +2212,7 @@ static int packet_snd(struct socket *sock,
        unsigned short gso_type = 0;
        int hlen, tlen;
        int extra_len = 0;
+       struct flow_keys keys;
 
        /*
         *      Get and verify the address.
@@ -2289,6 +2365,13 @@ static int packet_snd(struct socket *sock,
                len += vnet_hdr_len;
        }
 
+       if (skb->ip_summed == CHECKSUM_PARTIAL)
+               skb_set_transport_header(skb, skb_checksum_start_offset(skb));
+       else if (skb_flow_dissect(skb, &keys))
+               skb_set_transport_header(skb, keys.thoff);
+       else
+               skb_set_transport_header(skb, reserve);
+
        if (unlikely(extra_len == 4))
                skb->no_fcs = 1;
 
@@ -3240,7 +3323,8 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
        case PACKET_FANOUT:
                val = (po->fanout ?
                       ((u32)po->fanout->id |
-                       ((u32)po->fanout->type << 16)) :
+                       ((u32)po->fanout->type << 16) |
+                       ((u32)po->fanout->flags << 24)) :
                       0);
                break;
        case PACKET_TX_HAS_OFF: