net: vlan: prepare for 802.1ad VLAN filtering offload

Change the ndo_vlan_rx_{add,kill}_vid callbacks to take a protocol argument
in preparation for 802.1ad support. For now, the protocol argument passed by
all callers is always htons(ETH_P_8021Q).

Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/8021q/vlan_core.c b/net/8021q/vlan_core.c
index 3df29d3..04e3b95 100644
--- a/net/8021q/vlan_core.c
+++ b/net/8021q/vlan_core.c
@@ -185,35 +185,37 @@
 
 struct vlan_vid_info {
 	struct list_head list;
-	unsigned short vid;
+	__be16 proto;
+	u16 vid;
 	int refcount;
 };
 
 static struct vlan_vid_info *vlan_vid_info_get(struct vlan_info *vlan_info,
-					       unsigned short vid)
+					       __be16 proto, u16 vid)
 {
 	struct vlan_vid_info *vid_info;
 
 	list_for_each_entry(vid_info, &vlan_info->vid_list, list) {
-		if (vid_info->vid == vid)
+		if (vid_info->proto == proto && vid_info->vid == vid)
 			return vid_info;
 	}
 	return NULL;
 }
 
-static struct vlan_vid_info *vlan_vid_info_alloc(unsigned short vid)
+static struct vlan_vid_info *vlan_vid_info_alloc(__be16 proto, u16 vid)
 {
 	struct vlan_vid_info *vid_info;
 
 	vid_info = kzalloc(sizeof(struct vlan_vid_info), GFP_KERNEL);
 	if (!vid_info)
 		return NULL;
+	vid_info->proto = proto;
 	vid_info->vid = vid;
 
 	return vid_info;
 }
 
-static int __vlan_vid_add(struct vlan_info *vlan_info, unsigned short vid,
+static int __vlan_vid_add(struct vlan_info *vlan_info, __be16 proto, u16 vid,
 			  struct vlan_vid_info **pvid_info)
 {
 	struct net_device *dev = vlan_info->real_dev;
@@ -221,12 +223,13 @@
 	struct vlan_vid_info *vid_info;
 	int err;
 
-	vid_info = vlan_vid_info_alloc(vid);
+	vid_info = vlan_vid_info_alloc(proto, vid);
 	if (!vid_info)
 		return -ENOMEM;
 
-	if (dev->features & NETIF_F_HW_VLAN_CTAG_FILTER) {
-		err =  ops->ndo_vlan_rx_add_vid(dev, vid);
+	if (proto == htons(ETH_P_8021Q) &&
+	    dev->features & NETIF_F_HW_VLAN_CTAG_FILTER) {
+		err =  ops->ndo_vlan_rx_add_vid(dev, proto, vid);
 		if (err) {
 			kfree(vid_info);
 			return err;
@@ -238,7 +241,7 @@
 	return 0;
 }
 
-int vlan_vid_add(struct net_device *dev, unsigned short vid)
+int vlan_vid_add(struct net_device *dev, __be16 proto, u16 vid)
 {
 	struct vlan_info *vlan_info;
 	struct vlan_vid_info *vid_info;
@@ -254,9 +257,9 @@
 			return -ENOMEM;
 		vlan_info_created = true;
 	}
-	vid_info = vlan_vid_info_get(vlan_info, vid);
+	vid_info = vlan_vid_info_get(vlan_info, proto, vid);
 	if (!vid_info) {
-		err = __vlan_vid_add(vlan_info, vid, &vid_info);
+		err = __vlan_vid_add(vlan_info, proto, vid, &vid_info);
 		if (err)
 			goto out_free_vlan_info;
 	}
@@ -279,14 +282,16 @@
 {
 	struct net_device *dev = vlan_info->real_dev;
 	const struct net_device_ops *ops = dev->netdev_ops;
-	unsigned short vid = vid_info->vid;
+	__be16 proto = vid_info->proto;
+	u16 vid = vid_info->vid;
 	int err;
 
-	if (dev->features & NETIF_F_HW_VLAN_CTAG_FILTER) {
-		err = ops->ndo_vlan_rx_kill_vid(dev, vid);
+	if (proto == htons(ETH_P_8021Q) &&
+	    dev->features & NETIF_F_HW_VLAN_CTAG_FILTER) {
+		err = ops->ndo_vlan_rx_kill_vid(dev, proto, vid);
 		if (err) {
-			pr_warn("failed to kill vid %d for device %s\n",
-				vid, dev->name);
+			pr_warn("failed to kill vid %04x/%d for device %s\n",
+				proto, vid, dev->name);
 		}
 	}
 	list_del(&vid_info->list);
@@ -294,7 +299,7 @@
 	vlan_info->nr_vids--;
 }
 
-void vlan_vid_del(struct net_device *dev, unsigned short vid)
+void vlan_vid_del(struct net_device *dev, __be16 proto, u16 vid)
 {
 	struct vlan_info *vlan_info;
 	struct vlan_vid_info *vid_info;
@@ -305,7 +310,7 @@
 	if (!vlan_info)
 		return;
 
-	vid_info = vlan_vid_info_get(vlan_info, vid);
+	vid_info = vlan_vid_info_get(vlan_info, proto, vid);
 	if (!vid_info)
 		return;
 	vid_info->refcount--;
@@ -333,7 +338,7 @@
 		return 0;
 
 	list_for_each_entry(vid_info, &vlan_info->vid_list, list) {
-		err = vlan_vid_add(dev, vid_info->vid);
+		err = vlan_vid_add(dev, vid_info->proto, vid_info->vid);
 		if (err)
 			goto unwind;
 	}
@@ -343,7 +348,7 @@
 	list_for_each_entry_continue_reverse(vid_info,
 					     &vlan_info->vid_list,
 					     list) {
-		vlan_vid_del(dev, vid_info->vid);
+		vlan_vid_del(dev, vid_info->proto, vid_info->vid);
 	}
 
 	return err;
@@ -363,7 +368,7 @@
 		return;
 
 	list_for_each_entry(vid_info, &vlan_info->vid_list, list)
-		vlan_vid_del(dev, vid_info->vid);
+		vlan_vid_del(dev, vid_info->proto, vid_info->vid);
 }
 EXPORT_SYMBOL(vlan_vids_del_by_dev);