diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
index 288137cb08711af97c342e9bf9f839e575418f84..396490cf7316b6179a2ab3ccf83a2eb302841c9f 100644
--- a/arch/arm64/net/bpf_jit_comp.c
+++ b/arch/arm64/net/bpf_jit_comp.c
@@ -99,6 +99,20 @@ static inline void emit_a64_mov_i64(const int reg, const u64 val,
 	}
 }
 
+static inline void emit_addr_mov_i64(const int reg, const u64 val,
+				     struct jit_ctx *ctx)
+{
+	u64 tmp = val;
+	int shift = 0;
+
+	emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
+	for (;shift < 48;) {
+		tmp >>= 16;
+		shift += 16;
+		emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
+	}
+}
+
 static inline void emit_a64_mov_i(const int is64, const int reg,
 				  const s32 val, struct jit_ctx *ctx)
 {
@@ -603,7 +617,10 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
 		const u8 r0 = bpf2a64[BPF_REG_0];
 		const u64 func = (u64)__bpf_call_base + imm;
 
-		emit_a64_mov_i64(tmp, func, ctx);
+		if (ctx->prog->is_func)
+			emit_addr_mov_i64(tmp, func, ctx);
+		else
+			emit_a64_mov_i64(tmp, func, ctx);
 		emit(A64_BLR(tmp), ctx);
 		emit(A64_MOV(1, r0, A64_R(0)), ctx);
 		break;
@@ -835,11 +852,19 @@ static inline void bpf_flush_icache(void *start, void *end)
 	flush_icache_range((unsigned long)start, (unsigned long)end);
 }
 
+struct arm64_jit_data {
+	struct bpf_binary_header *header;
+	u8 *image;
+	struct jit_ctx ctx;
+};
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
 	struct bpf_prog *tmp, *orig_prog = prog;
 	struct bpf_binary_header *header;
+	struct arm64_jit_data *jit_data;
 	bool tmp_blinded = false;
+	bool extra_pass = false;
 	struct jit_ctx ctx;
 	int image_size;
 	u8 *image_ptr;
@@ -858,13 +883,29 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 		prog = tmp;
 	}
 
+	jit_data = prog->aux->jit_data;
+	if (!jit_data) {
+		jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+		if (!jit_data) {
+			prog = orig_prog;
+			goto out;
+		}
+		prog->aux->jit_data = jit_data;
+	}
+	if (jit_data->ctx.offset) {
+		ctx = jit_data->ctx;
+		image_ptr = jit_data->image;
+		header = jit_data->header;
+		extra_pass = true;
+		goto skip_init_ctx;
+	}
 	memset(&ctx, 0, sizeof(ctx));
 	ctx.prog = prog;
 
 	ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
 	if (ctx.offset == NULL) {
 		prog = orig_prog;
-		goto out;
+		goto out_off;
 	}
 
 	/* 1. Initial fake pass to compute ctx->idx. */
@@ -895,6 +936,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	/* 2. Now, the actual pass. */
 
 	ctx.image = (__le32 *)image_ptr;
+skip_init_ctx:
 	ctx.idx = 0;
 
 	build_prologue(&ctx);
@@ -920,13 +962,31 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 
 	bpf_flush_icache(header, ctx.image + ctx.idx);
 
-	bpf_jit_binary_lock_ro(header);
+	if (!prog->is_func || extra_pass) {
+		if (extra_pass && ctx.idx != jit_data->ctx.idx) {
+			pr_err_once("multi-func JIT bug %d != %d\n",
+				    ctx.idx, jit_data->ctx.idx);
+			bpf_jit_binary_free(header);
+			prog->bpf_func = NULL;
+			prog->jited = 0;
+			goto out_off;
+		}
+		bpf_jit_binary_lock_ro(header);
+	} else {
+		jit_data->ctx = ctx;
+		jit_data->image = image_ptr;
+		jit_data->header = header;
+	}
 	prog->bpf_func = (void *)ctx.image;
 	prog->jited = 1;
 	prog->jited_len = image_size;
 
+	if (!prog->is_func || extra_pass) {
 out_off:
-	kfree(ctx.offset);
+		kfree(ctx.offset);
+		kfree(jit_data);
+		prog->aux->jit_data = NULL;
+	}
 out:
 	if (tmp_blinded)
 		bpf_jit_prog_release_other(prog, prog == orig_prog ?