vhost: move per-vq net specific fields out to net
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index ec6fb3fa59bb5962281bb6277220941b91df053c..e34e195b9cf63d3291aaf2b9b5c9bbdad393b2df 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -64,20 +64,36 @@ enum {
        VHOST_NET_VQ_MAX = 2,
 };
 
-enum vhost_net_poll_state {
-       VHOST_NET_POLL_DISABLED = 0,
-       VHOST_NET_POLL_STARTED = 1,
-       VHOST_NET_POLL_STOPPED = 2,
+struct vhost_ubuf_ref {
+       struct kref kref;
+       wait_queue_head_t wait;
+       struct vhost_virtqueue *vq;
+};
+
+struct vhost_net_virtqueue {
+       struct vhost_virtqueue vq;
+       /* hdr is used to store the virtio header.
+        * Since each iovec has >= 1 byte length, we never need more than
+        * header length entries to store the header. */
+       struct iovec hdr[sizeof(struct virtio_net_hdr_mrg_rxbuf)];
+       size_t vhost_hlen;
+       size_t sock_hlen;
+       /* vhost zerocopy support fields below: */
+       /* last used idx for outstanding DMA zerocopy buffers */
+       int upend_idx;
+       /* first used idx for DMA done zerocopy buffers */
+       int done_idx;
+       /* an array of userspace buffers info */
+       struct ubuf_info *ubuf_info;
+       /* Reference counting for outstanding ubufs.
+        * Protected by vq mutex. Writers must also take device mutex. */
+       struct vhost_ubuf_ref *ubufs;
 };
 
 struct vhost_net {
        struct vhost_dev dev;
-       struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
+       struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
        struct vhost_poll poll[VHOST_NET_VQ_MAX];
-       /* Tells us whether we are polling a socket for TX.
-        * We only do this when socket buffer fills up.
-        * Protected by tx vq lock. */
-       enum vhost_net_poll_state tx_poll_state;
        /* Number of TX recently submitted.
         * Protected by tx vq lock. */
        unsigned tx_packets;
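
The key move in this hunk is embedding the generic struct vhost_virtqueue inside a net-specific wrapper, which later hunks recover with container_of(). Below is a minimal userspace sketch of that pattern; the two struct names mirror the diff, while the handler and field values are illustrative only.

/* Standalone sketch of the embed-and-container_of pattern used above:
 * the generic vq stays the unit the core passes around, and the
 * net-only state is recovered from it without any lookup table. */
#include <stddef.h>
#include <stdio.h>

#define container_of(ptr, type, member) \
        ((type *)((char *)(ptr) - offsetof(type, member)))

struct vhost_virtqueue { int num; };

struct vhost_net_virtqueue {
        struct vhost_virtqueue vq;      /* embedded generic vq */
        int upend_idx;                  /* net-only zerocopy state */
        int done_idx;
};

static void generic_handler(struct vhost_virtqueue *vq)
{
        /* Core code hands us only the generic vq; the wrapper and its
         * private fields come back via container_of(). */
        struct vhost_net_virtqueue *nvq =
                container_of(vq, struct vhost_net_virtqueue, vq);

        printf("pending zerocopy buffers: %d\n",
               nvq->upend_idx - nvq->done_idx);
}

int main(void)
{
        struct vhost_net_virtqueue nvq = { .upend_idx = 3, .done_idx = 1 };

        generic_handler(&nvq.vq);
        return 0;
}
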
@@ -88,6 +104,90 @@ struct vhost_net {
        bool tx_flush;
 };
 
+static unsigned vhost_zcopy_mask __read_mostly;
+
+void vhost_enable_zcopy(int vq)
+{
+       vhost_zcopy_mask |= 0x1 << vq;
+}
+
+static void vhost_zerocopy_done_signal(struct kref *kref)
+{
+       struct vhost_ubuf_ref *ubufs = container_of(kref, struct vhost_ubuf_ref,
+                                                   kref);
+       wake_up(&ubufs->wait);
+}
+
+struct vhost_ubuf_ref *vhost_ubuf_alloc(struct vhost_virtqueue *vq,
+                                       bool zcopy)
+{
+       struct vhost_ubuf_ref *ubufs;
+       /* No zero copy backend? Nothing to count. */
+       if (!zcopy)
+               return NULL;
+       ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
+       if (!ubufs)
+               return ERR_PTR(-ENOMEM);
+       kref_init(&ubufs->kref);
+       init_waitqueue_head(&ubufs->wait);
+       ubufs->vq = vq;
+       return ubufs;
+}
+
+void vhost_ubuf_put(struct vhost_ubuf_ref *ubufs)
+{
+       kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+}
+
+void vhost_ubuf_put_and_wait(struct vhost_ubuf_ref *ubufs)
+{
+       kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
+       wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
+       kfree(ubufs);
+}
+
+int vhost_net_set_ubuf_info(struct vhost_net *n)
+{
+       bool zcopy;
+       int i;
+
+       for (i = 0; i < n->dev.nvqs; ++i) {
+               zcopy = vhost_zcopy_mask & (0x1 << i);
+               if (!zcopy)
+                       continue;
+               n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
+                                             UIO_MAXIOV, GFP_KERNEL);
+               if (!n->vqs[i].ubuf_info)
+                       goto err;
+       }
+       return 0;
+
+err:
+       while (i--) {
+               zcopy = vhost_zcopy_mask & (0x1 << i);
+               if (!zcopy)
+                       continue;
+               kfree(n->vqs[i].ubuf_info);
+       }
+       return -ENOMEM;
+}
+
+void vhost_net_vq_reset(struct vhost_net *n)
+{
+       int i;
+
+       for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
+               n->vqs[i].done_idx = 0;
+               n->vqs[i].upend_idx = 0;
+               n->vqs[i].ubufs = NULL;
+               kfree(n->vqs[i].ubuf_info);
+               n->vqs[i].ubuf_info = NULL;
+               n->vqs[i].vhost_hlen = 0;
+               n->vqs[i].sock_hlen = 0;
+       }
+
+}
+
 static void vhost_net_tx_packet(struct vhost_net *net)
 {
        ++net->tx_packets;
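
The vhost_ubuf_ref introduced above is a kref paired with a waitqueue: one reference for the ring itself, one per in-flight zerocopy buffer, and the flush path drops its own reference and sleeps until the count reaches zero. The following userspace analogue uses pthreads in place of kref and wait_queue_head_t; the names and locking granularity are assumptions for illustration, not the kernel API.

#include <pthread.h>
#include <stdlib.h>

struct ubuf_ref {
        int count;                      /* stands in for the kref */
        pthread_mutex_t lock;
        pthread_cond_t zero;            /* stands in for wait_queue_head_t */
};

static struct ubuf_ref *ubuf_ref_alloc(void)
{
        struct ubuf_ref *r = malloc(sizeof(*r));

        if (!r)
                return NULL;
        r->count = 1;                   /* kref_init(): the ring's own reference */
        pthread_mutex_init(&r->lock, NULL);
        pthread_cond_init(&r->zero, NULL);
        return r;
}

static void ubuf_ref_get(struct ubuf_ref *r)    /* one per submitted zerocopy DMA */
{
        pthread_mutex_lock(&r->lock);
        r->count++;
        pthread_mutex_unlock(&r->lock);
}

static void ubuf_ref_put(struct ubuf_ref *r)    /* completion-callback side */
{
        pthread_mutex_lock(&r->lock);
        if (--r->count == 0)
                pthread_cond_broadcast(&r->zero);       /* wake_up() */
        pthread_mutex_unlock(&r->lock);
}

static void ubuf_ref_put_and_wait(struct ubuf_ref *r)   /* flush/teardown side */
{
        pthread_mutex_lock(&r->lock);
        r->count--;
        while (r->count)                /* wait_event(..., count == 0) */
                pthread_cond_wait(&r->zero, &r->lock);
        pthread_mutex_unlock(&r->lock);
        free(r);
}

int main(void)
{
        struct ubuf_ref *r = ubuf_ref_alloc();

        if (!r)
                return 1;
        ubuf_ref_get(r);                /* a buffer goes out to the NIC */
        ubuf_ref_put(r);                /* ...and its completion comes back */
        ubuf_ref_put_and_wait(r);       /* flush: returns once nothing is in flight */
        return 0;
}
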
@@ -155,28 +255,6 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
        }
 }
 
-/* Caller must have TX VQ lock */
-static void tx_poll_stop(struct vhost_net *net)
-{
-       if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
-               return;
-       vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
-       net->tx_poll_state = VHOST_NET_POLL_STOPPED;
-}
-
-/* Caller must have TX VQ lock */
-static int tx_poll_start(struct vhost_net *net, struct socket *sock)
-{
-       int ret;
-
-       if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
-               return 0;
-       ret = vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
-       if (!ret)
-               net->tx_poll_state = VHOST_NET_POLL_STARTED;
-       return ret;
-}
-
 /* In case of DMA done not in order in lower device driver for some reason.
  * upend_idx is used to track end of used idx, done_idx is used to track head
  * of used idx. Once lower device DMA done contiguously, we will signal KVM
@@ -185,10 +263,12 @@ static int tx_poll_start(struct vhost_net *net, struct socket *sock)
 static int vhost_zerocopy_signal_used(struct vhost_net *net,
                                      struct vhost_virtqueue *vq)
 {
+       struct vhost_net_virtqueue *nvq =
+               container_of(vq, struct vhost_net_virtqueue, vq);
        int i;
        int j = 0;
 
-       for (i = vq->done_idx; i != vq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
+       for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
                if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
                        vhost_net_tx_err(net);
                if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
@@ -200,7 +280,7 @@ static int vhost_zerocopy_signal_used(struct vhost_net *net,
                        break;
        }
        if (j)
-               vq->done_idx = i;
+               nvq->done_idx = i;
        return j;
 }
 
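
The scan above tolerates out-of-order DMA completions by marking heads[] entries done individually and advancing done_idx only across the contiguous completed prefix. A self-contained sketch of that bookkeeping follows, with a toy ring size and made-up completion states; the real loop additionally reports failures and signals the guest per slot.

#include <assert.h>

#define RING    8
#define BUSY    0
#define DONE    1

static int signal_used(int state[RING], int *done_idx, int upend_idx)
{
        int i, j = 0;

        for (i = *done_idx; i != upend_idx; i = (i + 1) % RING) {
                if (state[i] != DONE)
                        break;          /* stop at the first still-pending slot */
                ++j;
        }
        if (j)
                *done_idx = i;          /* everything before i can be reported */
        return j;
}

int main(void)
{
        /* Slots 2 and 3 finished, slot 4 is still in flight, slot 5 finished early. */
        int state[RING] = { [2] = DONE, [3] = DONE, [4] = BUSY, [5] = DONE };
        int done_idx = 2;

        assert(signal_used(state, &done_idx, 6) == 2);
        assert(done_idx == 4);          /* slot 5 waits until slot 4 completes */
        return 0;
}
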
@@ -230,7 +310,8 @@ static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
  * read-size critical section for our kind of RCU. */
 static void handle_tx(struct vhost_net *net)
 {
-       struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
+       struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+       struct vhost_virtqueue *vq = &nvq->vq;
        unsigned out, in, s;
        int head;
        struct msghdr msg = {
@@ -242,7 +323,7 @@ static void handle_tx(struct vhost_net *net)
                .msg_flags = MSG_DONTWAIT,
        };
        size_t len, total_len = 0;
-       int err, wmem;
+       int err;
        size_t hdr_size;
        struct socket *sock;
        struct vhost_ubuf_ref *uninitialized_var(ubufs);
@@ -253,21 +334,11 @@ static void handle_tx(struct vhost_net *net)
        if (!sock)
                return;
 
-       wmem = atomic_read(&sock->sk->sk_wmem_alloc);
-       if (wmem >= sock->sk->sk_sndbuf) {
-               mutex_lock(&vq->mutex);
-               tx_poll_start(net, sock);
-               mutex_unlock(&vq->mutex);
-               return;
-       }
-
        mutex_lock(&vq->mutex);
        vhost_disable_notify(&net->dev, vq);
 
-       if (wmem < sock->sk->sk_sndbuf / 2)
-               tx_poll_stop(net);
-       hdr_size = vq->vhost_hlen;
-       zcopy = vq->ubufs;
+       hdr_size = nvq->vhost_hlen;
+       zcopy = nvq->ubufs;
 
        for (;;) {
                /* Release DMAs done buffers first */
@@ -285,23 +356,15 @@ static void handle_tx(struct vhost_net *net)
                if (head == vq->num) {
                        int num_pends;
 
-                       wmem = atomic_read(&sock->sk->sk_wmem_alloc);
-                       if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
-                               tx_poll_start(net, sock);
-                               set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
-                               break;
-                       }
                        /* If more outstanding DMAs, queue the work.
                         * Handle upend_idx wrap around
                         */
-                       num_pends = likely(vq->upend_idx >= vq->done_idx) ?
-                                   (vq->upend_idx - vq->done_idx) :
-                                   (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
-                       if (unlikely(num_pends > VHOST_MAX_PEND)) {
-                               tx_poll_start(net, sock);
-                               set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
+                       num_pends = likely(nvq->upend_idx >= nvq->done_idx) ?
+                                   (nvq->upend_idx - nvq->done_idx) :
+                                   (nvq->upend_idx + UIO_MAXIOV -
+                                    nvq->done_idx);
+                       if (unlikely(num_pends > VHOST_MAX_PEND))
                                break;
-                       }
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                vhost_disable_notify(&net->dev, vq);
                                continue;
@@ -314,45 +377,45 @@ static void handle_tx(struct vhost_net *net)
                        break;
                }
                /* Skip header. TODO: support TSO. */
-               s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
+               s = move_iovec_hdr(vq->iov, nvq->hdr, hdr_size, out);
                msg.msg_iovlen = out;
                len = iov_length(vq->iov, out);
                /* Sanity check */
                if (!len) {
                        vq_err(vq, "Unexpected header len for TX: "
                               "%zd expected %zd\n",
-                              iov_length(vq->hdr, s), hdr_size);
+                              iov_length(nvq->hdr, s), hdr_size);
                        break;
                }
                zcopy_used = zcopy && (len >= VHOST_GOODCOPY_LEN ||
-                                      vq->upend_idx != vq->done_idx);
+                                      nvq->upend_idx != nvq->done_idx);
 
                /* use msg_control to pass vhost zerocopy ubuf info to skb */
                if (zcopy_used) {
-                       vq->heads[vq->upend_idx].id = head;
+                       vq->heads[nvq->upend_idx].id = head;
                        if (!vhost_net_tx_select_zcopy(net) ||
                            len < VHOST_GOODCOPY_LEN) {
                                /* copy don't need to wait for DMA done */
-                               vq->heads[vq->upend_idx].len =
+                               vq->heads[nvq->upend_idx].len =
                                                        VHOST_DMA_DONE_LEN;
                                msg.msg_control = NULL;
                                msg.msg_controllen = 0;
                                ubufs = NULL;
                        } else {
                                struct ubuf_info *ubuf;
-                               ubuf = vq->ubuf_info + vq->upend_idx;
+                               ubuf = nvq->ubuf_info + nvq->upend_idx;
 
-                               vq->heads[vq->upend_idx].len =
+                               vq->heads[nvq->upend_idx].len =
                                        VHOST_DMA_IN_PROGRESS;
                                ubuf->callback = vhost_zerocopy_callback;
-                               ubuf->ctx = vq->ubufs;
-                               ubuf->desc = vq->upend_idx;
+                               ubuf->ctx = nvq->ubufs;
+                               ubuf->desc = nvq->upend_idx;
                                msg.msg_control = ubuf;
                                msg.msg_controllen = sizeof(ubuf);
-                               ubufs = vq->ubufs;
+                               ubufs = nvq->ubufs;
                                kref_get(&ubufs->kref);
                        }
-                       vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
+                       nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
                }
                /* TODO: Check specific error and bomb out unless ENOBUFS? */
                err = sock->ops->sendmsg(NULL, sock, &msg, len);
@@ -360,12 +423,10 @@ static void handle_tx(struct vhost_net *net)
                        if (zcopy_used) {
                                if (ubufs)
                                        vhost_ubuf_put(ubufs);
-                               vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
-                                       UIO_MAXIOV;
+                               nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
+                                       % UIO_MAXIOV;
                        }
                        vhost_discard_vq_desc(vq, 1);
-                       if (err == -EAGAIN || err == -ENOBUFS)
-                               tx_poll_start(net, sock);
                        break;
                }
                if (err != len)
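
Both the pending-buffer count earlier in handle_tx() and the error-path rollback above rely on modular arithmetic over a UIO_MAXIOV-sized ring. A small standalone check of that arithmetic, with illustrative index values:

#include <assert.h>

#define UIO_MAXIOV 1024

static int num_pending(int upend_idx, int done_idx)
{
        /* Same wrap-around computation as num_pends in handle_tx(). */
        return upend_idx >= done_idx ?
               upend_idx - done_idx :
               upend_idx + UIO_MAXIOV - done_idx;
}

int main(void)
{
        unsigned upend_idx = 0;

        assert(num_pending(10, 4) == 6);                /* no wrap */
        assert(num_pending(2, UIO_MAXIOV - 3) == 5);    /* wrapped around */

        /* Rolling back one slot after a failed sendmsg(), as the error
         * path above does: the unsigned subtraction wraps, and the
         * modulo folds the index back into the ring. */
        upend_idx = ((unsigned)upend_idx - 1) % UIO_MAXIOV;
        assert(upend_idx == UIO_MAXIOV - 1);
        return 0;
}
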
@@ -470,7 +531,8 @@ err:
  * read-size critical section for our kind of RCU. */
 static void handle_rx(struct vhost_net *net)
 {
-       struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
+       struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
+       struct vhost_virtqueue *vq = &nvq->vq;
        unsigned uninitialized_var(in), log;
        struct vhost_log *vq_log;
        struct msghdr msg = {
@@ -498,8 +560,8 @@ static void handle_rx(struct vhost_net *net)
 
        mutex_lock(&vq->mutex);
        vhost_disable_notify(&net->dev, vq);
-       vhost_hlen = vq->vhost_hlen;
-       sock_hlen = vq->sock_hlen;
+       vhost_hlen = nvq->vhost_hlen;
+       sock_hlen = nvq->sock_hlen;
 
        vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
                vq->log : NULL;
@@ -529,11 +591,11 @@ static void handle_rx(struct vhost_net *net)
                /* We don't need to be notified again. */
                if (unlikely((vhost_hlen)))
                        /* Skip header. TODO: support TSO. */
-                       move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
+                       move_iovec_hdr(vq->iov, nvq->hdr, vhost_hlen, in);
                else
                        /* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
                         * needed because recvmsg can modify msg_iov. */
-                       copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
+                       copy_iovec_hdr(vq->iov, nvq->hdr, sock_hlen, in);
                msg.msg_iovlen = in;
                err = sock->ops->recvmsg(NULL, sock, &msg,
                                         sock_len, MSG_DONTWAIT | MSG_TRUNC);
@@ -547,7 +609,7 @@ static void handle_rx(struct vhost_net *net)
                        continue;
                }
                if (unlikely(vhost_hlen) &&
-                   memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
+                   memcpy_toiovecend(nvq->hdr, (unsigned char *)&hdr, 0,
                                      vhost_hlen)) {
                        vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
                               vq->iov->iov_base);
@@ -555,7 +617,7 @@ static void handle_rx(struct vhost_net *net)
                }
                /* TODO: Should check and handle checksum. */
                if (likely(mergeable) &&
-                   memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
+                   memcpy_toiovecend(nvq->hdr, (unsigned char *)&headcount,
                                      offsetof(typeof(hdr), num_buffers),
                                      sizeof hdr.num_buffers)) {
                        vq_err(vq, "Failed num_buffers write");
@@ -612,23 +674,39 @@ static int vhost_net_open(struct inode *inode, struct file *f)
 {
        struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
        struct vhost_dev *dev;
-       int r;
+       struct vhost_virtqueue **vqs;
+       int r, i;
 
        if (!n)
                return -ENOMEM;
+       vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
+       if (!vqs) {
+               kfree(n);
+               return -ENOMEM;
+       }
 
        dev = &n->dev;
-       n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
-       n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
-       r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
+       vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
+       vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
+       n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
+       n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
+       for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
+               n->vqs[i].ubufs = NULL;
+               n->vqs[i].ubuf_info = NULL;
+               n->vqs[i].upend_idx = 0;
+               n->vqs[i].done_idx = 0;
+               n->vqs[i].vhost_hlen = 0;
+               n->vqs[i].sock_hlen = 0;
+       }
+       r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
        if (r < 0) {
                kfree(n);
+               kfree(vqs);
                return r;
        }
 
        vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
        vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
-       n->tx_poll_state = VHOST_NET_POLL_DISABLED;
 
        f->private_data = n;
 
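
vhost_net_open() now allocates a separate array of vhost_virtqueue pointers for vhost_dev_init() and aims each entry at the vq embedded in the corresponding vhost_net_virtqueue. The sketch below mirrors that wiring and the cleanup order of the hunk; the types and the dev_init() stub are stand-ins, not the real vhost interfaces.

#include <stdlib.h>

#define VQ_MAX 2

struct vq { int num; };
struct net_vq { struct vq vq; int upend_idx, done_idx; };
struct net { struct net_vq vqs[VQ_MAX]; struct vq **dev_vqs; };

static int dev_init(struct vq **vqs, int nvqs) { (void)vqs; (void)nvqs; return 0; }

static struct net *net_open(void)
{
        struct net *n = malloc(sizeof(*n));
        struct vq **vqs;
        int i;

        if (!n)
                return NULL;
        vqs = malloc(VQ_MAX * sizeof(*vqs));
        if (!vqs) {
                free(n);
                return NULL;
        }
        for (i = 0; i < VQ_MAX; i++) {
                vqs[i] = &n->vqs[i].vq;         /* point at the embedded vq */
                n->vqs[i].upend_idx = 0;
                n->vqs[i].done_idx = 0;
        }
        if (dev_init(vqs, VQ_MAX) < 0) {
                free(n);                        /* same order as the hunk: n, then vqs */
                free(vqs);
                return NULL;
        }
        n->dev_vqs = vqs;                       /* freed again at release time */
        return n;
}

int main(void)
{
        struct net *n = net_open();

        if (n) {
                free(n->dev_vqs);
                free(n);
        }
        return 0;
}
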
@@ -638,32 +716,28 @@ static int vhost_net_open(struct inode *inode, struct file *f)
 static void vhost_net_disable_vq(struct vhost_net *n,
                                 struct vhost_virtqueue *vq)
 {
+       struct vhost_net_virtqueue *nvq =
+               container_of(vq, struct vhost_net_virtqueue, vq);
+       struct vhost_poll *poll = n->poll + (nvq - n->vqs);
        if (!vq->private_data)
                return;
-       if (vq == n->vqs + VHOST_NET_VQ_TX) {
-               tx_poll_stop(n);
-               n->tx_poll_state = VHOST_NET_POLL_DISABLED;
-       } else
-               vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
+       vhost_poll_stop(poll);
 }
 
 static int vhost_net_enable_vq(struct vhost_net *n,
                                struct vhost_virtqueue *vq)
 {
+       struct vhost_net_virtqueue *nvq =
+               container_of(vq, struct vhost_net_virtqueue, vq);
+       struct vhost_poll *poll = n->poll + (nvq - n->vqs);
        struct socket *sock;
-       int ret;
 
        sock = rcu_dereference_protected(vq->private_data,
                                         lockdep_is_held(&vq->mutex));
        if (!sock)
                return 0;
-       if (vq == n->vqs + VHOST_NET_VQ_TX) {
-               n->tx_poll_state = VHOST_NET_POLL_STOPPED;
-               ret = tx_poll_start(n, sock);
-       } else
-               ret = vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
 
-       return ret;
+       return vhost_poll_start(poll, sock->file);
 }
 
 static struct socket *vhost_net_stop_vq(struct vhost_net *n,
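
With the per-vq wrapper in place, vhost_net_enable_vq() and vhost_net_disable_vq() above pick the matching poll entry by pointer arithmetic instead of comparing against the TX vq. The short sketch below shows why nvq - n->vqs yields the array index; the struct layouts are illustrative stand-ins.

#include <assert.h>
#include <stddef.h>

#define VQ_MAX 2

struct nvq { int id; };
struct npoll { int id; };

struct net {
        struct nvq vqs[VQ_MAX];
        struct npoll poll[VQ_MAX];
};

int main(void)
{
        struct net n;
        struct nvq *rx = &n.vqs[1];             /* e.g. the RX queue */
        ptrdiff_t idx = rx - n.vqs;             /* element count, not bytes */
        struct npoll *p = n.poll + idx;         /* matching slot in the poll array */

        assert(idx == 1 && p == &n.poll[1]);
        return 0;
}
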
@@ -683,30 +757,30 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
 static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
                           struct socket **rx_sock)
 {
-       *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
-       *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
+       *tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
+       *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
 }
 
 static void vhost_net_flush_vq(struct vhost_net *n, int index)
 {
        vhost_poll_flush(n->poll + index);
-       vhost_poll_flush(&n->dev.vqs[index].poll);
+       vhost_poll_flush(&n->vqs[index].vq.poll);
 }
 
 static void vhost_net_flush(struct vhost_net *n)
 {
        vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
        vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
-       if (n->dev.vqs[VHOST_NET_VQ_TX].ubufs) {
-               mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+       if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
+               mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                n->tx_flush = true;
-               mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+               mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                /* Wait for all lower device DMAs done. */
-               vhost_ubuf_put_and_wait(n->dev.vqs[VHOST_NET_VQ_TX].ubufs);
-               mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+               vhost_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
+               mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                n->tx_flush = false;
-               kref_init(&n->dev.vqs[VHOST_NET_VQ_TX].ubufs->kref);
-               mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+               kref_init(&n->vqs[VHOST_NET_VQ_TX].ubufs->kref);
+               mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
        }
 }
 
@@ -720,6 +794,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
        vhost_net_flush(n);
        vhost_dev_stop(&n->dev);
        vhost_dev_cleanup(&n->dev, false);
+       vhost_net_vq_reset(n);
        if (tx_sock)
                fput(tx_sock->file);
        if (rx_sock)
@@ -727,6 +802,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
        /* We do an extra flush before freeing memory,
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);
+       kfree(n->dev.vqs);
        kfree(n);
        return 0;
 }
@@ -800,6 +876,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 {
        struct socket *sock, *oldsock;
        struct vhost_virtqueue *vq;
+       struct vhost_net_virtqueue *nvq;
        struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
        int r;
 
@@ -812,7 +889,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
                r = -ENOBUFS;
                goto err;
        }
-       vq = n->vqs + index;
+       vq = &n->vqs[index].vq;
+       nvq = &n->vqs[index];
        mutex_lock(&vq->mutex);
 
        /* Verify that ring has been setup correctly. */
@@ -845,8 +923,8 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
                if (r)
                        goto err_used;
 
-               oldubufs = vq->ubufs;
-               vq->ubufs = ubufs;
+               oldubufs = nvq->ubufs;
+               nvq->ubufs = ubufs;
 
                n->tx_packets = 0;
                n->tx_zcopy_err = 0;
@@ -897,6 +975,7 @@ static long vhost_net_reset_owner(struct vhost_net *n)
        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        err = vhost_dev_reset_owner(&n->dev);
+       vhost_net_vq_reset(n);
 done:
        mutex_unlock(&n->dev.mutex);
        if (tx_sock)
@@ -932,10 +1011,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
        n->dev.acked_features = features;
        smp_wmb();
        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
-               mutex_lock(&n->vqs[i].mutex);
+               mutex_lock(&n->vqs[i].vq.mutex);
                n->vqs[i].vhost_hlen = vhost_hlen;
                n->vqs[i].sock_hlen = sock_hlen;
-               mutex_unlock(&n->vqs[i].mutex);
+               mutex_unlock(&n->vqs[i].vq.mutex);
        }
        vhost_net_flush(n);
        mutex_unlock(&n->dev.mutex);
@@ -972,11 +1051,17 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
                return vhost_net_reset_owner(n);
        default:
                mutex_lock(&n->dev.mutex);
+               if (ioctl == VHOST_SET_OWNER) {
+                       r = vhost_net_set_ubuf_info(n);
+                       if (r)
+                               goto out;
+               }
                r = vhost_dev_ioctl(&n->dev, ioctl, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vring_ioctl(&n->dev, ioctl, argp);
                else
                        vhost_net_flush(n);
+out:
                mutex_unlock(&n->dev.mutex);
                return r;
        }
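
The default ioctl branch above now calls vhost_net_set_ubuf_info() before handing VHOST_SET_OWNER to the core. That helper's error path unwinds with while (i--), freeing only the slots it managed to allocate. A compact sketch of the idiom, with a made-up mask value and buffer size:

#include <stdlib.h>

#define NVQS 2

static unsigned zcopy_mask = 0x1;       /* e.g. zerocopy enabled on vq 0 only */

static int set_bufs(void *bufs[NVQS], size_t sz)
{
        int i;

        for (i = 0; i < NVQS; ++i) {
                if (!(zcopy_mask & (0x1 << i)))
                        continue;
                bufs[i] = malloc(sz);
                if (!bufs[i])
                        goto err;
        }
        return 0;

err:
        while (i--)                     /* frees only the already-allocated slots */
                if (zcopy_mask & (0x1 << i))
                        free(bufs[i]);
        return -1;                      /* -ENOMEM in the kernel code */
}

int main(void)
{
        void *bufs[NVQS] = { NULL, NULL };

        if (set_bufs(bufs, 64))
                return 1;
        free(bufs[0]);                  /* vq 1 has zerocopy disabled, nothing allocated */
        return 0;
}
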