diff --git a/crypto/chacha20poly1305.c b/crypto/chacha20poly1305.c
index 4d6f51bcdfabe903c804845ecaa6182f21a3c702..af8afe5c06ea9a827df39c65b0bda11176943043 100644
--- a/crypto/chacha20poly1305.c
+++ b/crypto/chacha20poly1305.c
@@ -67,6 +67,8 @@ struct chachapoly_req_ctx {
 	unsigned int cryptlen;
 	/* Actual AD, excluding IV */
 	unsigned int assoclen;
+	/* request flags, with MAY_SLEEP cleared if needed */
+	u32 flags;
 	union {
 		struct poly_req poly;
 		struct chacha_req chacha;
@@ -76,8 +78,12 @@ struct chachapoly_req_ctx {
 static inline void async_done_continue(struct aead_request *req, int err,
 				       int (*cont)(struct aead_request *))
 {
-	if (!err)
+	if (!err) {
+		struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
+
+		rctx->flags &= ~CRYPTO_TFM_REQ_MAY_SLEEP;
 		err = cont(req);
+	}
 
 	if (err != -EINPROGRESS && err != -EBUSY)
 		aead_request_complete(req, err);
@@ -144,7 +150,7 @@ static int chacha_decrypt(struct aead_request *req)
 		dst = scatterwalk_ffwd(rctx->dst, req->dst, req->assoclen);
 	}
 
-	skcipher_request_set_callback(&creq->req, aead_request_flags(req),
+	skcipher_request_set_callback(&creq->req, rctx->flags,
 				      chacha_decrypt_done, req);
 	skcipher_request_set_tfm(&creq->req, ctx->chacha);
 	skcipher_request_set_crypt(&creq->req, src, dst,
@@ -188,7 +194,7 @@ static int poly_tail(struct aead_request *req)
 	memcpy(&preq->tail.cryptlen, &len, sizeof(len));
 	sg_set_buf(preq->src, &preq->tail, sizeof(preq->tail));
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_tail_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 	ahash_request_set_crypt(&preq->req, preq->src,
@@ -219,7 +225,7 @@ static int poly_cipherpad(struct aead_request *req)
 	sg_init_table(preq->src, 1);
 	sg_set_buf(preq->src, &preq->pad, padlen);
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_cipherpad_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 	ahash_request_set_crypt(&preq->req, preq->src, NULL, padlen);
@@ -250,7 +256,7 @@ static int poly_cipher(struct aead_request *req)
 	sg_init_table(rctx->src, 2);
 	crypt = scatterwalk_ffwd(rctx->src, crypt, req->assoclen);
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_cipher_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 	ahash_request_set_crypt(&preq->req, crypt, NULL, rctx->cryptlen);
@@ -280,7 +286,7 @@ static int poly_adpad(struct aead_request *req)
 	sg_init_table(preq->src, 1);
 	sg_set_buf(preq->src, preq->pad, padlen);
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_adpad_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 	ahash_request_set_crypt(&preq->req, preq->src, NULL, padlen);
@@ -304,7 +310,7 @@ static int poly_ad(struct aead_request *req)
 	struct poly_req *preq = &rctx->u.poly;
 	int err;
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_ad_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 	ahash_request_set_crypt(&preq->req, req->src, NULL, rctx->assoclen);
@@ -331,7 +337,7 @@ static int poly_setkey(struct aead_request *req)
 	sg_init_table(preq->src, 1);
 	sg_set_buf(preq->src, rctx->key, sizeof(rctx->key));
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_setkey_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 	ahash_request_set_crypt(&preq->req, preq->src, NULL, sizeof(rctx->key));
@@ -355,7 +361,7 @@ static int poly_init(struct aead_request *req)
 	struct poly_req *preq = &rctx->u.poly;
 	int err;
 
-	ahash_request_set_callback(&preq->req, aead_request_flags(req),
+	ahash_request_set_callback(&preq->req, rctx->flags,
 				   poly_init_done, req);
 	ahash_request_set_tfm(&preq->req, ctx->poly);
 
@@ -393,7 +399,7 @@ static int poly_genkey(struct aead_request *req)
 
 	chacha_iv(creq->iv, req, 0);
 
-	skcipher_request_set_callback(&creq->req, aead_request_flags(req),
+	skcipher_request_set_callback(&creq->req, rctx->flags,
 				      poly_genkey_done, req);
 	skcipher_request_set_tfm(&creq->req, ctx->chacha);
 	skcipher_request_set_crypt(&creq->req, creq->src, creq->src,
@@ -433,7 +439,7 @@ static int chacha_encrypt(struct aead_request *req)
 		dst = scatterwalk_ffwd(rctx->dst, req->dst, req->assoclen);
 	}
 
-	skcipher_request_set_callback(&creq->req, aead_request_flags(req),
+	skcipher_request_set_callback(&creq->req, rctx->flags,
 				      chacha_encrypt_done, req);
 	skcipher_request_set_tfm(&creq->req, ctx->chacha);
 	skcipher_request_set_crypt(&creq->req, src, dst,
@@ -451,6 +457,7 @@ static int chachapoly_encrypt(struct aead_request *req)
 	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
 
 	rctx->cryptlen = req->cryptlen;
+	rctx->flags = aead_request_flags(req);
 
 	/* encrypt call chain:
 	 * - chacha_encrypt/done()
@@ -472,6 +479,7 @@ static int chachapoly_decrypt(struct aead_request *req)
 	struct chachapoly_req_ctx *rctx = aead_request_ctx(req);
 
 	rctx->cryptlen = req->cryptlen - POLY1305_DIGEST_SIZE;
+	rctx->flags = aead_request_flags(req);
 
 	/* decrypt call chain:
 	 * - poly_genkey/done()