]> Pileus Git - ~andy/linux/blobdiff - drivers/vhost/net.c
vhost: fix total length when packets are too short
[~andy/linux] / drivers / vhost / net.c
index 9a68409580d5b76d8d3972e42a33314dc22f16f8..026be580d318481f6ed3732dd06b6fb2ea2c5160 100644 (file)
@@ -70,7 +70,12 @@ enum {
 };
 
 struct vhost_net_ubuf_ref {
-       struct kref kref;
+       /* refcount follows semantics similar to kref:
+        *  0: object is released
+        *  1: no outstanding ubufs
+        * >1: outstanding ubufs
+        */
+       atomic_t refcount;
        wait_queue_head_t wait;
        struct vhost_virtqueue *vq;
 };
@@ -116,14 +121,6 @@ static void vhost_net_enable_zcopy(int vq)
        vhost_net_zcopy_mask |= 0x1 << vq;
 }
 
-static void vhost_net_zerocopy_done_signal(struct kref *kref)
-{
-       struct vhost_net_ubuf_ref *ubufs;
-
-       ubufs = container_of(kref, struct vhost_net_ubuf_ref, kref);
-       wake_up(&ubufs->wait);
-}
-
 static struct vhost_net_ubuf_ref *
 vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
 {
@@ -134,21 +131,24 @@ vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
        ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
        if (!ubufs)
                return ERR_PTR(-ENOMEM);
-       kref_init(&ubufs->kref);
+       atomic_set(&ubufs->refcount, 1);
        init_waitqueue_head(&ubufs->wait);
        ubufs->vq = vq;
        return ubufs;
 }
 
-static void vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
+static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
 {
-       kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal);
+       int r = atomic_sub_return(1, &ubufs->refcount);
+       if (unlikely(!r))
+               wake_up(&ubufs->wait);
+       return r;
 }
 
 static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
 {
-       kref_put(&ubufs->kref, vhost_net_zerocopy_done_signal);
-       wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
+       vhost_net_ubuf_put(ubufs);
+       wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
 }
 
 static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
@@ -306,23 +306,26 @@ static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
 {
        struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
        struct vhost_virtqueue *vq = ubufs->vq;
-       int cnt = atomic_read(&ubufs->kref.refcount);
+       int cnt;
+
+       rcu_read_lock_bh();
 
        /* set len to mark this desc buffers done DMA */
        vq->heads[ubuf->desc].len = success ?
                VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
-       vhost_net_ubuf_put(ubufs);
+       cnt = vhost_net_ubuf_put(ubufs);
 
        /*
         * Trigger polling thread if guest stopped submitting new buffers:
-        * in this case, the refcount after decrement will eventually reach 1
-        * so here it is 2.
+        * in this case, the refcount after decrement will eventually reach 1.
         * We also trigger polling periodically after each 16 packets
         * (the value 16 here is more or less arbitrary, it's tuned to trigger
         * less than 10% of times).
         */
-       if (cnt <= 2 || !(cnt % 16))
+       if (cnt <= 1 || !(cnt % 16))
                vhost_poll_queue(&vq->poll);
+
+       rcu_read_unlock_bh();
 }
 
 /* Expects to be always run from workqueue - which acts as
@@ -420,7 +423,7 @@ static void handle_tx(struct vhost_net *net)
                        msg.msg_control = ubuf;
                        msg.msg_controllen = sizeof(ubuf);
                        ubufs = nvq->ubufs;
-                       kref_get(&ubufs->kref);
+                       atomic_inc(&ubufs->refcount);
                        nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
                } else {
                        msg.msg_control = NULL;
@@ -529,6 +532,12 @@ static int get_rx_bufs(struct vhost_virtqueue *vq,
        *iovcount = seg;
        if (unlikely(log))
                *log_num = nlogs;
+
+       /* Detect overrun */
+       if (unlikely(datalen > 0)) {
+               r = UIO_MAXIOV + 1;
+               goto err;
+       }
        return headcount;
 err:
        vhost_discard_vq_desc(vq, headcount);
@@ -584,6 +593,14 @@ static void handle_rx(struct vhost_net *net)
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        break;
+               /* On overrun, truncate and discard */
+               if (unlikely(headcount > UIO_MAXIOV)) {
+                       msg.msg_iovlen = 1;
+                       err = sock->ops->recvmsg(NULL, sock, &msg,
+                                                1, MSG_DONTWAIT | MSG_TRUNC);
+                       pr_debug("Discarded rx packet: len %zd\n", sock_len);
+                       continue;
+               }
                /* OK, now we need to know about added descriptors. */
                if (!headcount) {
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
@@ -780,7 +797,7 @@ static void vhost_net_flush(struct vhost_net *n)
                vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
                mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                n->tx_flush = false;
-               kref_init(&n->vqs[VHOST_NET_VQ_TX].ubufs->kref);
+               atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
                mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
        }
 }
@@ -800,6 +817,8 @@ static int vhost_net_release(struct inode *inode, struct file *f)
                fput(tx_sock->file);
        if (rx_sock)
                fput(rx_sock->file);
+       /* Make sure no callbacks are outstanding */
+       synchronize_rcu_bh();
        /* We do an extra flush before freeing memory,
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);