]> Pileus Git - ~andy/linux/blobdiff - tools/lguest/lguest.c
Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[~andy/linux] / tools / lguest / lguest.c
index fd2f9221b24120e25d32c2c8d8b4849d8d24569a..07a03452c227e3804a04661f26e7ca8cc73bec35 100644 (file)
@@ -179,29 +179,6 @@ static struct termios orig_term;
 #define wmb() __asm__ __volatile__("" : : : "memory")
 #define mb() __asm__ __volatile__("" : : : "memory")
 
-/*
- * Convert an iovec element to the given type.
- *
- * This is a fairly ugly trick: we need to know the size of the type and
- * alignment requirement to check the pointer is kosher.  It's also nice to
- * have the name of the type in case we report failure.
- *
- * Typing those three things all the time is cumbersome and error prone, so we
- * have a macro which sets them all up and passes to the real function.
- */
-#define convert(iov, type) \
-       ((type *)_convert((iov), sizeof(type), __alignof__(type), #type))
-
-static void *_convert(struct iovec *iov, size_t size, size_t align,
-                     const char *name)
-{
-       if (iov->iov_len != size)
-               errx(1, "Bad iovec size %zu for %s", iov->iov_len, name);
-       if ((unsigned long)iov->iov_base % align != 0)
-               errx(1, "Bad alignment %p for %s", iov->iov_base, name);
-       return iov->iov_base;
-}
-
 /* Wrapper for the last available index.  Makes it easier to change. */
 #define lg_last_avail(vq)      ((vq)->last_avail_idx)
 
@@ -228,7 +205,8 @@ static bool iov_empty(const struct iovec iov[], unsigned int num_iov)
 }
 
 /* Take len bytes from the front of this iovec. */
-static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len)
+static void iov_consume(struct iovec iov[], unsigned num_iov,
+                       void *dest, unsigned len)
 {
        unsigned int i;
 
@@ -236,11 +214,16 @@ static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len)
                unsigned int used;
 
                used = iov[i].iov_len < len ? iov[i].iov_len : len;
+               if (dest) {
+                       memcpy(dest, iov[i].iov_base, used);
+                       dest += used;
+               }
                iov[i].iov_base += used;
                iov[i].iov_len -= used;
                len -= used;
        }
-       assert(len == 0);
+       if (len != 0)
+               errx(1, "iovec too short!");
 }
 
 /* The device virtqueue descriptors are followed by feature bitmasks. */
@@ -864,7 +847,7 @@ static void console_output(struct virtqueue *vq)
                        warn("Write to stdout gave %i (%d)", len, errno);
                        break;
                }
-               iov_consume(iov, out, len);
+               iov_consume(iov, out, NULL, len);
        }
 
        /*
@@ -1591,9 +1574,9 @@ static void blk_request(struct virtqueue *vq)
 {
        struct vblk_info *vblk = vq->dev->priv;
        unsigned int head, out_num, in_num, wlen;
-       int ret;
+       int ret, i;
        u8 *in;
-       struct virtio_blk_outhdr *out;
+       struct virtio_blk_outhdr out;
        struct iovec iov[vq->vring.num];
        off64_t off;
 
@@ -1603,32 +1586,36 @@ static void blk_request(struct virtqueue *vq)
         */
        head = wait_for_vq_desc(vq, iov, &out_num, &in_num);
 
-       /*
-        * Every block request should contain at least one output buffer
-        * (detailing the location on disk and the type of request) and one
-        * input buffer (to hold the result).
-        */
-       if (out_num == 0 || in_num == 0)
-               errx(1, "Bad virtblk cmd %u out=%u in=%u",
-                    head, out_num, in_num);
+       /* Copy the output header from the front of the iov (adjusts iov) */
+       iov_consume(iov, out_num, &out, sizeof(out));
+
+       /* Find and trim end of iov input array, for our status byte. */
+       in = NULL;
+       for (i = out_num + in_num - 1; i >= out_num; i--) {
+               if (iov[i].iov_len > 0) {
+                       in = iov[i].iov_base + iov[i].iov_len - 1;
+                       iov[i].iov_len--;
+                       break;
+               }
+       }
+       if (!in)
+               errx(1, "Bad virtblk cmd with no room for status");
 
-       out = convert(&iov[0], struct virtio_blk_outhdr);
-       in = convert(&iov[out_num+in_num-1], u8);
        /*
         * For historical reasons, block operations are expressed in 512 byte
         * "sectors".
         */
-       off = out->sector * 512;
+       off = out.sector * 512;
 
        /*
         * In general the virtio block driver is allowed to try SCSI commands.
         * It'd be nice if we supported eject, for example, but we don't.
         */
-       if (out->type & VIRTIO_BLK_T_SCSI_CMD) {
+       if (out.type & VIRTIO_BLK_T_SCSI_CMD) {
                fprintf(stderr, "Scsi commands unsupported\n");
                *in = VIRTIO_BLK_S_UNSUPP;
                wlen = sizeof(*in);
-       } else if (out->type & VIRTIO_BLK_T_OUT) {
+       } else if (out.type & VIRTIO_BLK_T_OUT) {
                /*
                 * Write
                 *
@@ -1636,10 +1623,10 @@ static void blk_request(struct virtqueue *vq)
                 * if they try to write past end.
                 */
                if (lseek64(vblk->fd, off, SEEK_SET) != off)
-                       err(1, "Bad seek to sector %llu", out->sector);
+                       err(1, "Bad seek to sector %llu", out.sector);
 
-               ret = writev(vblk->fd, iov+1, out_num-1);
-               verbose("WRITE to sector %llu: %i\n", out->sector, ret);
+               ret = writev(vblk->fd, iov, out_num);
+               verbose("WRITE to sector %llu: %i\n", out.sector, ret);
 
                /*
                 * Grr... Now we know how long the descriptor they sent was, we
@@ -1655,7 +1642,7 @@ static void blk_request(struct virtqueue *vq)
 
                wlen = sizeof(*in);
                *in = (ret >= 0 ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR);
-       } else if (out->type & VIRTIO_BLK_T_FLUSH) {
+       } else if (out.type & VIRTIO_BLK_T_FLUSH) {
                /* Flush */
                ret = fdatasync(vblk->fd);
                verbose("FLUSH fdatasync: %i\n", ret);
@@ -1669,10 +1656,9 @@ static void blk_request(struct virtqueue *vq)
                 * if they try to read past end.
                 */
                if (lseek64(vblk->fd, off, SEEK_SET) != off)
-                       err(1, "Bad seek to sector %llu", out->sector);
+                       err(1, "Bad seek to sector %llu", out.sector);
 
-               ret = readv(vblk->fd, iov+1, in_num-1);
-               verbose("READ from sector %llu: %i\n", out->sector, ret);
+               ret = readv(vblk->fd, iov + out_num, in_num);
                if (ret >= 0) {
                        wlen = sizeof(*in) + ret;
                        *in = VIRTIO_BLK_S_OK;
@@ -1758,7 +1744,7 @@ static void rng_input(struct virtqueue *vq)
                len = readv(rng_info->rfd, iov, in_num);
                if (len <= 0)
                        err(1, "Read from /dev/random gave %i", len);
-               iov_consume(iov, in_num, len);
+               iov_consume(iov, in_num, NULL, len);
                totlen += len;
        }