]> Pileus Git - ~andy/linux/blobdiff - drivers/vhost/vhost.c
Merge branch 'vfs-scale-working' of git://git.kernel.org/pub/scm/linux/kernel/git...
[~andy/linux] / drivers / vhost / vhost.c
index 159c77a5746fecfd6c6eeb687cf11e747e194fba..38244f59cdd91e76acc63f7fb4d7b66a5a730fbf 100644 (file)
@@ -15,6 +15,7 @@
 #include <linux/vhost.h>
 #include <linux/virtio_net.h>
 #include <linux/mm.h>
+#include <linux/mmu_context.h>
 #include <linux/miscdevice.h>
 #include <linux/mutex.h>
 #include <linux/rcupdate.h>
@@ -29,8 +30,6 @@
 #include <linux/if_packet.h>
 #include <linux/if_arp.h>
 
-#include <net/sock.h>
-
 #include "vhost.h"
 
 enum {
@@ -157,7 +156,6 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        vq->avail_idx = 0;
        vq->last_used_idx = 0;
        vq->used_flags = 0;
-       vq->used_flags = 0;
        vq->log_used = false;
        vq->log_addr = -1ull;
        vq->vhost_hlen = 0;
@@ -178,6 +176,8 @@ static int vhost_worker(void *data)
        struct vhost_work *work = NULL;
        unsigned uninitialized_var(seq);
 
+       use_mm(dev->mm);
+
        for (;;) {
                /* mb paired w/ kthread_stop */
                set_current_state(TASK_INTERRUPTIBLE);
@@ -192,7 +192,7 @@ static int vhost_worker(void *data)
                if (kthread_should_stop()) {
                        spin_unlock_irq(&dev->work_lock);
                        __set_current_state(TASK_RUNNING);
-                       return 0;
+                       break;
                }
                if (!list_empty(&dev->work_list)) {
                        work = list_first_entry(&dev->work_list,
@@ -210,6 +210,8 @@ static int vhost_worker(void *data)
                        schedule();
 
        }
+       unuse_mm(dev->mm);
+       return 0;
 }
 
 /* Helper to allocate iovec buffers for all vqs. */
@@ -402,15 +404,14 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
        kfree(rcu_dereference_protected(dev->memory,
                                        lockdep_is_held(&dev->mutex)));
        RCU_INIT_POINTER(dev->memory, NULL);
-       if (dev->mm)
-               mmput(dev->mm);
-       dev->mm = NULL;
-
        WARN_ON(!list_empty(&dev->work_list));
        if (dev->worker) {
                kthread_stop(dev->worker);
                dev->worker = NULL;
        }
+       if (dev->mm)
+               mmput(dev->mm);
+       dev->mm = NULL;
 }
 
 static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
@@ -881,15 +882,15 @@ static int set_bit_to_user(int nr, void __user *addr)
 static int log_write(void __user *log_base,
                     u64 write_address, u64 write_length)
 {
+       u64 write_page = write_address / VHOST_PAGE_SIZE;
        int r;
        if (!write_length)
                return 0;
        write_length += write_address % VHOST_PAGE_SIZE;
-       write_address /= VHOST_PAGE_SIZE;
        for (;;) {
                u64 base = (u64)(unsigned long)log_base;
-               u64 log = base + write_address / 8;
-               int bit = write_address % 8;
+               u64 log = base + write_page / 8;
+               int bit = write_page % 8;
                if ((u64)(unsigned long)log != log)
                        return -EFAULT;
                r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
@@ -898,7 +899,7 @@ static int log_write(void __user *log_base,
                if (write_length <= VHOST_PAGE_SIZE)
                        break;
                write_length -= VHOST_PAGE_SIZE;
-               write_address += 1;
+               write_page += 1;
        }
        return r;
 }
@@ -1093,7 +1094,7 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 
        /* Check it isn't doing very strange things with descriptor numbers. */
        last_avail_idx = vq->last_avail_idx;
-       if (unlikely(get_user(vq->avail_idx, &vq->avail->idx))) {
+       if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) {
                vq_err(vq, "Failed to access avail idx at %p\n",
                       &vq->avail->idx);
                return -EFAULT;
@@ -1114,8 +1115,8 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 
        /* Grab the next descriptor number they're advertising, and increment
         * the index we've seen. */
-       if (unlikely(get_user(head,
-                             &vq->avail->ring[last_avail_idx % vq->num]))) {
+       if (unlikely(__get_user(head,
+                               &vq->avail->ring[last_avail_idx % vq->num]))) {
                vq_err(vq, "Failed to read head: idx %d address %p\n",
                       last_avail_idx,
                       &vq->avail->ring[last_avail_idx % vq->num]);
@@ -1214,17 +1215,17 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
        /* The virtqueue contains a ring of used buffers.  Get a pointer to the
         * next entry in that used ring. */
        used = &vq->used->ring[vq->last_used_idx % vq->num];
-       if (put_user(head, &used->id)) {
+       if (__put_user(head, &used->id)) {
                vq_err(vq, "Failed to write used id");
                return -EFAULT;
        }
-       if (put_user(len, &used->len)) {
+       if (__put_user(len, &used->len)) {
                vq_err(vq, "Failed to write used len");
                return -EFAULT;
        }
        /* Make sure buffer is written before we update index. */
        smp_wmb();
-       if (put_user(vq->last_used_idx + 1, &vq->used->idx)) {
+       if (__put_user(vq->last_used_idx + 1, &vq->used->idx)) {
                vq_err(vq, "Failed to increment used idx");
                return -EFAULT;
        }
@@ -1256,7 +1257,7 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
 
        start = vq->last_used_idx % vq->num;
        used = vq->used->ring + start;
-       if (copy_to_user(used, heads, count * sizeof *used)) {
+       if (__copy_to_user(used, heads, count * sizeof *used)) {
                vq_err(vq, "Failed to write used");
                return -EFAULT;
        }
@@ -1317,7 +1318,7 @@ void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
         * interrupts. */
        smp_mb();
 
-       if (get_user(flags, &vq->avail->flags)) {
+       if (__get_user(flags, &vq->avail->flags)) {
                vq_err(vq, "Failed to get flags");
                return;
        }
@@ -1368,7 +1369,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
        /* They could have slipped one in as we were doing that: make
         * sure it's written, then check again. */
        smp_mb();
-       r = get_user(avail_idx, &vq->avail->idx);
+       r = __get_user(avail_idx, &vq->avail->idx);
        if (r) {
                vq_err(vq, "Failed to check avail idx at %p: %d\n",
                       &vq->avail->idx, r);