]> Pileus Git - ~andy/linux/blobdiff - net/netfilter/ipvs/ip_vs_proto.c
ipvs: use GFP_KERNEL allocation where possible
[~andy/linux] / net / netfilter / ipvs / ip_vs_proto.c
index f843a88332509edc9ac7ed509cba6d679685664f..e91c8982dfac37915953631380fbf02d5c98cb44 100644 (file)
@@ -48,7 +48,7 @@ static struct ip_vs_protocol *ip_vs_proto_table[IP_VS_PROTO_TAB_SIZE];
  */
 static int __used __init register_ip_vs_protocol(struct ip_vs_protocol *pp)
 {
-       unsigned hash = IP_VS_PROTO_HASH(pp->protocol);
+       unsigned int hash = IP_VS_PROTO_HASH(pp->protocol);
 
        pp->next = ip_vs_proto_table[hash];
        ip_vs_proto_table[hash] = pp;
@@ -59,9 +59,6 @@ static int __used __init register_ip_vs_protocol(struct ip_vs_protocol *pp)
        return 0;
 }
 
-#if defined(CONFIG_IP_VS_PROTO_TCP) || defined(CONFIG_IP_VS_PROTO_UDP) || \
-    defined(CONFIG_IP_VS_PROTO_SCTP) || defined(CONFIG_IP_VS_PROTO_AH) || \
-    defined(CONFIG_IP_VS_PROTO_ESP)
 /*
  *     register an ipvs protocols netns related data
  */
@@ -69,9 +66,9 @@ static int
 register_ip_vs_proto_netns(struct net *net, struct ip_vs_protocol *pp)
 {
        struct netns_ipvs *ipvs = net_ipvs(net);
-       unsigned hash = IP_VS_PROTO_HASH(pp->protocol);
+       unsigned int hash = IP_VS_PROTO_HASH(pp->protocol);
        struct ip_vs_proto_data *pd =
-                       kzalloc(sizeof(struct ip_vs_proto_data), GFP_ATOMIC);
+                       kzalloc(sizeof(struct ip_vs_proto_data), GFP_KERNEL);
 
        if (!pd)
                return -ENOMEM;
@@ -81,12 +78,18 @@ register_ip_vs_proto_netns(struct net *net, struct ip_vs_protocol *pp)
        ipvs->proto_data_table[hash] = pd;
        atomic_set(&pd->appcnt, 0);     /* Init app counter */
 
-       if (pp->init_netns != NULL)
-               pp->init_netns(net, pd);
+       if (pp->init_netns != NULL) {
+               int ret = pp->init_netns(net, pd);
+               if (ret) {
+                       /* unlink an free proto data */
+                       ipvs->proto_data_table[hash] = pd->next;
+                       kfree(pd);
+                       return ret;
+               }
+       }
 
        return 0;
 }
-#endif
 
 /*
  *     unregister an ipvs protocol
@@ -94,7 +97,7 @@ register_ip_vs_proto_netns(struct net *net, struct ip_vs_protocol *pp)
 static int unregister_ip_vs_protocol(struct ip_vs_protocol *pp)
 {
        struct ip_vs_protocol **pp_p;
-       unsigned hash = IP_VS_PROTO_HASH(pp->protocol);
+       unsigned int hash = IP_VS_PROTO_HASH(pp->protocol);
 
        pp_p = &ip_vs_proto_table[hash];
        for (; *pp_p; pp_p = &(*pp_p)->next) {
@@ -117,7 +120,7 @@ unregister_ip_vs_proto_netns(struct net *net, struct ip_vs_proto_data *pd)
 {
        struct netns_ipvs *ipvs = net_ipvs(net);
        struct ip_vs_proto_data **pd_p;
-       unsigned hash = IP_VS_PROTO_HASH(pd->pp->protocol);
+       unsigned int hash = IP_VS_PROTO_HASH(pd->pp->protocol);
 
        pd_p = &ipvs->proto_data_table[hash];
        for (; *pd_p; pd_p = &(*pd_p)->next) {
@@ -139,7 +142,7 @@ unregister_ip_vs_proto_netns(struct net *net, struct ip_vs_proto_data *pd)
 struct ip_vs_protocol * ip_vs_proto_get(unsigned short proto)
 {
        struct ip_vs_protocol *pp;
-       unsigned hash = IP_VS_PROTO_HASH(proto);
+       unsigned int hash = IP_VS_PROTO_HASH(proto);
 
        for (pp = ip_vs_proto_table[hash]; pp; pp = pp->next) {
                if (pp->protocol == proto)
@@ -157,7 +160,7 @@ struct ip_vs_proto_data *
 __ipvs_proto_data_get(struct netns_ipvs *ipvs, unsigned short proto)
 {
        struct ip_vs_proto_data *pd;
-       unsigned hash = IP_VS_PROTO_HASH(proto);
+       unsigned int hash = IP_VS_PROTO_HASH(proto);
 
        for (pd = ipvs->proto_data_table[hash]; pd; pd = pd->next) {
                if (pd->pp->protocol == proto)
@@ -196,7 +199,7 @@ void ip_vs_protocol_timeout_change(struct netns_ipvs *ipvs, int flags)
 int *
 ip_vs_create_timeout_table(int *table, int size)
 {
-       return kmemdup(table, size, GFP_ATOMIC);
+       return kmemdup(table, size, GFP_KERNEL);
 }
 
 
@@ -316,22 +319,35 @@ ip_vs_tcpudp_debug_packet(int af, struct ip_vs_protocol *pp,
  */
 int __net_init ip_vs_protocol_net_init(struct net *net)
 {
+       int i, ret;
+       static struct ip_vs_protocol *protos[] = {
 #ifdef CONFIG_IP_VS_PROTO_TCP
-       register_ip_vs_proto_netns(net, &ip_vs_protocol_tcp);
+        &ip_vs_protocol_tcp,
 #endif
 #ifdef CONFIG_IP_VS_PROTO_UDP
-       register_ip_vs_proto_netns(net, &ip_vs_protocol_udp);
+       &ip_vs_protocol_udp,
 #endif
 #ifdef CONFIG_IP_VS_PROTO_SCTP
-       register_ip_vs_proto_netns(net, &ip_vs_protocol_sctp);
+       &ip_vs_protocol_sctp,
 #endif
 #ifdef CONFIG_IP_VS_PROTO_AH
-       register_ip_vs_proto_netns(net, &ip_vs_protocol_ah);
+       &ip_vs_protocol_ah,
 #endif
 #ifdef CONFIG_IP_VS_PROTO_ESP
-       register_ip_vs_proto_netns(net, &ip_vs_protocol_esp);
+       &ip_vs_protocol_esp,
 #endif
+       };
+
+       for (i = 0; i < ARRAY_SIZE(protos); i++) {
+               ret = register_ip_vs_proto_netns(net, protos[i]);
+               if (ret < 0)
+                       goto cleanup;
+       }
        return 0;
+
+cleanup:
+       ip_vs_protocol_net_cleanup(net);
+       return ret;
 }
 
 void __net_exit ip_vs_protocol_net_cleanup(struct net *net)