]> Pileus Git - ~andy/linux/blobdiff - net/sunrpc/auth_gss/auth_gss.c
Merge branch 'slab/for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/penber...
[~andy/linux] / net / sunrpc / auth_gss / auth_gss.c
index 7da6b457f66abfab016fd8b21aeedcb14d5e7ff0..fc2f78d6a9b46fae51a3ba366bc50c23e9238101 100644 (file)
@@ -52,6 +52,8 @@
 #include <linux/sunrpc/gss_api.h>
 #include <asm/uaccess.h>
 
+#include "../netns.h"
+
 static const struct rpc_authops authgss_ops;
 
 static const struct rpc_credops gss_credops;
@@ -85,8 +87,6 @@ struct gss_auth {
 };
 
 /* pipe_version >= 0 if and only if someone has a pipe open. */
-static int pipe_version = -1;
-static atomic_t pipe_users = ATOMIC_INIT(0);
 static DEFINE_SPINLOCK(pipe_version_lock);
 static struct rpc_wait_queue pipe_version_rpc_waitqueue;
 static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
@@ -266,24 +266,27 @@ struct gss_upcall_msg {
        char databuf[UPCALL_BUF_LEN];
 };
 
-static int get_pipe_version(void)
+static int get_pipe_version(struct net *net)
 {
+       struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
        int ret;
 
        spin_lock(&pipe_version_lock);
-       if (pipe_version >= 0) {
-               atomic_inc(&pipe_users);
-               ret = pipe_version;
+       if (sn->pipe_version >= 0) {
+               atomic_inc(&sn->pipe_users);
+               ret = sn->pipe_version;
        } else
                ret = -EAGAIN;
        spin_unlock(&pipe_version_lock);
        return ret;
 }
 
-static void put_pipe_version(void)
+static void put_pipe_version(struct net *net)
 {
-       if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
-               pipe_version = -1;
+       struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+       if (atomic_dec_and_lock(&sn->pipe_users, &pipe_version_lock)) {
+               sn->pipe_version = -1;
                spin_unlock(&pipe_version_lock);
        }
 }
@@ -291,9 +294,10 @@ static void put_pipe_version(void)
 static void
 gss_release_msg(struct gss_upcall_msg *gss_msg)
 {
+       struct net *net = rpc_net_ns(gss_msg->auth->client);
        if (!atomic_dec_and_test(&gss_msg->count))
                return;
-       put_pipe_version();
+       put_pipe_version(net);
        BUG_ON(!list_empty(&gss_msg->list));
        if (gss_msg->ctx != NULL)
                gss_put_ctx(gss_msg->ctx);
@@ -439,7 +443,10 @@ static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
                                struct rpc_clnt *clnt,
                                const char *service_name)
 {
-       if (pipe_version == 0)
+       struct net *net = rpc_net_ns(clnt);
+       struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
+
+       if (sn->pipe_version == 0)
                gss_encode_v0_msg(gss_msg);
        else /* pipe_version == 1 */
                gss_encode_v1_msg(gss_msg, clnt, service_name);
@@ -455,7 +462,7 @@ gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt,
        gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
        if (gss_msg == NULL)
                return ERR_PTR(-ENOMEM);
-       vers = get_pipe_version();
+       vers = get_pipe_version(rpc_net_ns(clnt));
        if (vers < 0) {
                kfree(gss_msg);
                return ERR_PTR(vers);
@@ -559,24 +566,34 @@ out:
 static inline int
 gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
 {
+       struct net *net = rpc_net_ns(gss_auth->client);
+       struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
        struct rpc_pipe *pipe;
        struct rpc_cred *cred = &gss_cred->gc_base;
        struct gss_upcall_msg *gss_msg;
+       unsigned long timeout;
        DEFINE_WAIT(wait);
-       int err = 0;
+       int err;
 
        dprintk("RPC:       %s for uid %u\n",
                __func__, from_kuid(&init_user_ns, cred->cr_uid));
 retry:
+       err = 0;
+       /* Default timeout is 15s unless we know that gssd is not running */
+       timeout = 15 * HZ;
+       if (!sn->gssd_running)
+               timeout = HZ >> 2;
        gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
        if (PTR_ERR(gss_msg) == -EAGAIN) {
                err = wait_event_interruptible_timeout(pipe_version_waitqueue,
-                               pipe_version >= 0, 15*HZ);
-               if (pipe_version < 0) {
+                               sn->pipe_version >= 0, timeout);
+               if (sn->pipe_version < 0) {
+                       if (err == 0)
+                               sn->gssd_running = 0;
                        warn_gssd();
                        err = -EACCES;
                }
-               if (err)
+               if (err < 0)
                        goto out;
                goto retry;
        }
@@ -707,20 +724,22 @@ out:
 
 static int gss_pipe_open(struct inode *inode, int new_version)
 {
+       struct net *net = inode->i_sb->s_fs_info;
+       struct sunrpc_net *sn = net_generic(net, sunrpc_net_id);
        int ret = 0;
 
        spin_lock(&pipe_version_lock);
-       if (pipe_version < 0) {
+       if (sn->pipe_version < 0) {
                /* First open of any gss pipe determines the version: */
-               pipe_version = new_version;
+               sn->pipe_version = new_version;
                rpc_wake_up(&pipe_version_rpc_waitqueue);
                wake_up(&pipe_version_waitqueue);
-       } else if (pipe_version != new_version) {
+       } else if (sn->pipe_version != new_version) {
                /* Trying to open a pipe of a different version */
                ret = -EBUSY;
                goto out;
        }
-       atomic_inc(&pipe_users);
+       atomic_inc(&sn->pipe_users);
 out:
        spin_unlock(&pipe_version_lock);
        return ret;
@@ -740,6 +759,7 @@ static int gss_pipe_open_v1(struct inode *inode)
 static void
 gss_pipe_release(struct inode *inode)
 {
+       struct net *net = inode->i_sb->s_fs_info;
        struct rpc_pipe *pipe = RPC_I(inode)->pipe;
        struct gss_upcall_msg *gss_msg;
 
@@ -758,7 +778,7 @@ restart:
        }
        spin_unlock(&pipe->lock);
 
-       put_pipe_version();
+       put_pipe_version(net);
 }
 
 static void