diff --git a/drivers/net/dsa/b53/b53_common.c b/drivers/net/dsa/b53/b53_common.c
index bb210b12ad1b90f9485d1bec96e15973a9feb4fd..31afc4d4b68b959e7457a615bb98cfb6e9e630d3 100644
--- a/drivers/net/dsa/b53/b53_common.c
+++ b/drivers/net/dsa/b53/b53_common.c
@@ -1790,14 +1790,15 @@ struct b53_device *b53_switch_alloc(struct device *base,
 	struct dsa_switch *ds;
 	struct b53_device *dev;
 
-	ds = devm_kzalloc(base, sizeof(*ds) + sizeof(*dev), GFP_KERNEL);
+	ds = dsa_switch_alloc(base, DSA_MAX_PORTS);
 	if (!ds)
 		return NULL;
 
-	dev = (struct b53_device *)(ds + 1);
+	dev = devm_kzalloc(base, sizeof(*dev), GFP_KERNEL);
+	if (!dev)
+		return NULL;
 
 	ds->priv = dev;
-	ds->dev = base;
 	dev->dev = base;
 
 	dev->ds = ds;
diff --git a/drivers/net/dsa/mv88e6xxx/chip.c b/drivers/net/dsa/mv88e6xxx/chip.c
index 921e5335178663dd3e346033e410d0b85daef07a..cb7b24748336d531827c076cdf019cd61bb72a4f 100644
--- a/drivers/net/dsa/mv88e6xxx/chip.c
+++ b/drivers/net/dsa/mv88e6xxx/chip.c
@@ -4361,11 +4361,10 @@ static int mv88e6xxx_register_switch(struct mv88e6xxx_chip *chip)
 	struct device *dev = chip->dev;
 	struct dsa_switch *ds;
 
-	ds = devm_kzalloc(dev, sizeof(*ds), GFP_KERNEL);
+	ds = dsa_switch_alloc(dev, DSA_MAX_PORTS);
 	if (!ds)
 		return -ENOMEM;
 
-	ds->dev = dev;
 	ds->priv = chip;
 	ds->ops = &mv88e6xxx_switch_ops;
 
diff --git a/drivers/net/dsa/qca8k.c b/drivers/net/dsa/qca8k.c
index c084aa484d2bd100312cd6d2302f86faee5e28b8..f67c6a3cebff82b8e20b82846f77a30301237f9c 100644
--- a/drivers/net/dsa/qca8k.c
+++ b/drivers/net/dsa/qca8k.c
@@ -954,12 +954,11 @@ qca8k_sw_probe(struct mdio_device *mdiodev)
 	if (id != QCA8K_ID_QCA8337)
 		return -ENODEV;
 
-	priv->ds = devm_kzalloc(&mdiodev->dev, sizeof(*priv->ds), GFP_KERNEL);
+	priv->ds = dsa_switch_alloc(&mdiodev->dev, DSA_MAX_PORTS);
 	if (!priv->ds)
 		return -ENOMEM;
 
 	priv->ds->priv = priv;
-	priv->ds->dev = &mdiodev->dev;
 	priv->ds->ops = &qca8k_switch_ops;
 	mutex_init(&priv->reg_mutex);
 	dev_set_drvdata(&mdiodev->dev, priv);
diff --git a/include/net/dsa.h b/include/net/dsa.h
index 92fd795e9573e36c241b688dcbd4cdc20954c98a..24e1d935ae681171df38c5b0dfb0f14b513ba180 100644
--- a/include/net/dsa.h
+++ b/include/net/dsa.h
@@ -190,8 +190,11 @@ struct dsa_switch {
 	u32			cpu_port_mask;
 	u32			enabled_port_mask;
 	u32			phys_mii_mask;
-	struct dsa_port		ports[DSA_MAX_PORTS];
 	struct mii_bus		*slave_mii_bus;
+
+	/* Dynamically allocated ports, keep last */
+	size_t num_ports;
+	struct dsa_port ports[];
 };
 
 static inline bool dsa_is_cpu_port(struct dsa_switch *ds, int p)
@@ -386,6 +389,7 @@ static inline bool dsa_uses_tagged_protocol(struct dsa_switch_tree *dst)
 	return dst->rcv != NULL;
 }
 
+struct dsa_switch *dsa_switch_alloc(struct device *dev, size_t n);
 void dsa_unregister_switch(struct dsa_switch *ds);
 int dsa_register_switch(struct dsa_switch *ds, struct device *dev);
 #ifdef CONFIG_PM_SLEEP
diff --git a/net/dsa/dsa.c b/net/dsa/dsa.c
index 07e863369e04826d9386a1320bf42bdfd7a8b109..de3ffb421ee4be59c74b6bceb2c27674cc03dae3 100644
--- a/net/dsa/dsa.c
+++ b/net/dsa/dsa.c
@@ -347,8 +347,8 @@ dsa_switch_setup(struct dsa_switch_tree *dst, int index,
 	/*
 	 * Allocate and initialise switch state.
 	 */
-	ds = devm_kzalloc(parent, sizeof(*ds), GFP_KERNEL);
-	if (ds == NULL)
+	ds = dsa_switch_alloc(parent, DSA_MAX_PORTS);
+	if (!ds)
 		return ERR_PTR(-ENOMEM);
 
 	ds->dst = dst;
@@ -356,7 +356,6 @@ dsa_switch_setup(struct dsa_switch_tree *dst, int index,
 	ds->cd = cd;
 	ds->ops = ops;
 	ds->priv = priv;
-	ds->dev = parent;
 
 	ret = dsa_switch_setup_one(ds, parent);
 	if (ret)
diff --git a/net/dsa/dsa2.c b/net/dsa/dsa2.c
index 75f5d1f8554b82bcc5fdce8d2c0c17eb7bc06749..4b3a44bec5c89b33fe4382c48417c68917e63dd9 100644
--- a/net/dsa/dsa2.c
+++ b/net/dsa/dsa2.c
@@ -666,6 +666,22 @@ static int _dsa_register_switch(struct dsa_switch *ds, struct device *dev)
 	return err;
 }
 
+struct dsa_switch *dsa_switch_alloc(struct device *dev, size_t n)
+{
+	size_t size = sizeof(struct dsa_switch) + n * sizeof(struct dsa_port);
+	struct dsa_switch *ds;
+
+	ds = devm_kzalloc(dev, size, GFP_KERNEL);
+	if (!ds)
+		return NULL;
+
+	ds->dev = dev;
+	ds->num_ports = n;
+
+	return ds;
+}
+EXPORT_SYMBOL_GPL(dsa_switch_alloc);
+
 int dsa_register_switch(struct dsa_switch *ds, struct device *dev)
 {
 	int err;