diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 669fef1e2bb6094816fad5ed322d947994c86693..73dd16d0f587c294eea8bce9781ab5a3d92211b6 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -131,6 +131,19 @@ static void vhost_reset_is_le(struct vhost_virtqueue *vq)
 	vq->is_le = virtio_legacy_is_little_endian();
 }
 
+struct vhost_flush_struct {
+	struct vhost_work work;
+	struct completion wait_event;
+};
+
+static void vhost_flush_work(struct vhost_work *work)
+{
+	struct vhost_flush_struct *s;
+
+	s = container_of(work, struct vhost_flush_struct, work);
+	complete(&s->wait_event);
+}
+
 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
 			    poll_table *pt)
 {
@@ -158,8 +171,6 @@ void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
 	INIT_LIST_HEAD(&work->node);
 	work->fn = fn;
 	init_waitqueue_head(&work->done);
-	work->flushing = 0;
-	work->queue_seq = work->done_seq = 0;
 }
 EXPORT_SYMBOL_GPL(vhost_work_init);
 
@@ -211,31 +222,17 @@ void vhost_poll_stop(struct vhost_poll *poll)
 }
 EXPORT_SYMBOL_GPL(vhost_poll_stop);
 
-static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
-				unsigned seq)
-{
-	int left;
-
-	spin_lock_irq(&dev->work_lock);
-	left = seq - work->done_seq;
-	spin_unlock_irq(&dev->work_lock);
-	return left <= 0;
-}
-
 void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
 {
-	unsigned seq;
-	int flushing;
+	struct vhost_flush_struct flush;
+
+	if (dev->worker) {
+		init_completion(&flush.wait_event);
+		vhost_work_init(&flush.work, vhost_flush_work);
 
-	spin_lock_irq(&dev->work_lock);
-	seq = work->queue_seq;
-	work->flushing++;
-	spin_unlock_irq(&dev->work_lock);
-	wait_event(work->done, vhost_work_seq_done(dev, work, seq));
-	spin_lock_irq(&dev->work_lock);
-	flushing = --work->flushing;
-	spin_unlock_irq(&dev->work_lock);
-	BUG_ON(flushing < 0);
+		vhost_work_queue(dev, &flush.work);
+		wait_for_completion(&flush.wait_event);
+	}
 }
 EXPORT_SYMBOL_GPL(vhost_work_flush);
 
@@ -254,7 +251,6 @@ void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
 	spin_lock_irqsave(&dev->work_lock, flags);
 	if (list_empty(&work->node)) {
 		list_add_tail(&work->node, &dev->work_list);
-		work->queue_seq++;
 		spin_unlock_irqrestore(&dev->work_lock, flags);
 		wake_up_process(dev->worker);
 	} else {
@@ -310,7 +306,6 @@ static int vhost_worker(void *data)
 {
 	struct vhost_dev *dev = data;
 	struct vhost_work *work = NULL;
-	unsigned uninitialized_var(seq);
 	mm_segment_t oldfs = get_fs();
 
 	set_fs(USER_DS);
@@ -321,11 +316,6 @@ static int vhost_worker(void *data)
 		set_current_state(TASK_INTERRUPTIBLE);
 
 		spin_lock_irq(&dev->work_lock);
-		if (work) {
-			work->done_seq = seq;
-			if (work->flushing)
-				wake_up_all(&work->done);
-		}
 
 		if (kthread_should_stop()) {
 			spin_unlock_irq(&dev->work_lock);
@@ -336,7 +326,6 @@ static int vhost_worker(void *data)
 			work = list_first_entry(&dev->work_list,
 						struct vhost_work, node);
 			list_del_init(&work->node);
-			seq = work->queue_seq;
 		} else
 			work = NULL;
 		spin_unlock_irq(&dev->work_lock);