]> git.decadent.org.uk Git - nfs-utils.git/blobdiff - utils/svcgssd/svcgssd_proc.c
Updates from Kevin Coffman at UMich
[nfs-utils.git] / utils / svcgssd / svcgssd_proc.c
index dfa3c4c8d7b69f4f1cfcaf2ff3585bc859ae88c2..b43a023882675d6788fe758010fe4e7dddb18f95 100644 (file)
@@ -56,6 +56,8 @@ extern char * mech2file(gss_OID mech);
 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.rpcsec.context/channel"
 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.rpcsec.init/channel"
 
+#define TOKEN_BUF_SIZE         8192
+
 struct svc_cred {
        uid_t   cr_uid;
        gid_t   cr_gid;
@@ -111,7 +113,7 @@ send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
              u_int32_t maj_stat, u_int32_t min_stat,
              gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
 {
-       char buf[2 * 4096];
+       char buf[2 * TOKEN_BUF_SIZE];
        char *bp = buf;
        int blen = sizeof(buf);
        /* XXXARG: */
@@ -189,25 +191,37 @@ get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
        char            *sname;
        int             res = -1;
        uid_t           uid, gid;
-       gss_OID         name_type;
+       gss_OID         name_type = GSS_C_NO_OID;
        char            *secname;
        gid_t           *groups;
 
        maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
-       if (maj_stat != GSS_S_COMPLETE)
+       if (maj_stat != GSS_S_COMPLETE) {
+               pgsserr("get_ids: gss_display_name",
+                       maj_stat, min_stat, mech);
                goto out;
-       if (!(sname = calloc(name.length + 1, 1)))
+       }
+       if (!(sname = calloc(name.length + 1, 1))) {
+               printerr(0, "WARNING: get_ids: error allocating %d bytes "
+                       "for sname\n", name.length + 1);
                goto out;
+       }
        memcpy(sname, name.value, name.length);
        printerr(1, "sname = %s\n", sname);
 
        res = -EINVAL;
-       if ((secname = mech2file(mech)) == NULL)
+       if ((secname = mech2file(mech)) == NULL) {
+               printerr(0, "WARNING: get_ids: error mapping mech to "
+                       "file for name '%s'\n", sname);
                goto out_free;
+       }
        nfs4_init_name_mapping(NULL); /* XXX: should only do this once */
        res = nfs4_gss_princ_to_ids(secname, sname, &uid, &gid);
-       if (res < 0)
+       if (res < 0) {
+               printerr(0, "WARNING: get_ids: unable to map "
+                       "name '%s' to a uid\n", sname);
                goto out_free;
+       }
        cred->cr_uid = uid;
        cred->cr_gid = gid;
        add_supplementary_groups(secname, sname, cred);
@@ -215,8 +229,6 @@ get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
 out_free:
        free(sname);
 out:
-       if (res)
-               printerr(0, "WARNING: get_uid failed\n");
        return res;
 }
 
@@ -262,7 +274,7 @@ handle_nullreq(FILE *f) {
        /* XXX initialize to a random integer to reduce chances of unnecessary
         * invalidation of existing ctx's on restarting svcgssd. */
        static u_int32_t        handle_seq = 0;
-       char                    in_tok_buf[8192];
+       char                    in_tok_buf[TOKEN_BUF_SIZE];
        char                    in_handle_buf[15];
        char                    out_handle_buf[15];
        gss_buffer_desc         in_tok = {.value = in_tok_buf},
@@ -275,7 +287,7 @@ handle_nullreq(FILE *f) {
        u_int32_t               ret_flags;
        gss_ctx_id_t            ctx = GSS_C_NO_CONTEXT;
        gss_name_t              client_name;
-       gss_OID                 mech;
+       gss_OID                 mech = GSS_C_NO_OID;
        u_int32_t               maj_stat = GSS_S_FAILURE, min_stat = 0;
        struct svc_cred         cred;
        static char             *lbuf = NULL;
@@ -296,9 +308,6 @@ handle_nullreq(FILE *f) {
                                              sizeof(in_handle_buf));
        printerr(2, "in_handle: \n");
        print_hexl(2, in_handle.value, in_handle.length);
-       handle_seq++;
-       out_handle.length = sizeof(handle_seq);
-       memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
 
        in_tok.length = (size_t) qword_get(&cp, in_tok.value,
                                           sizeof(in_tok_buf));
@@ -312,26 +321,48 @@ handle_nullreq(FILE *f) {
        }
 
        if (in_handle.length != 0) { /* CONTINUE_INIT case */
-               printerr(0, "WARNING: handle_nullreq: "
-                           "CONTINUE_INIT unsupported\n");
-               goto out_err;
+               if (in_handle.length != sizeof(ctx)) {
+                       printerr(0, "WARNING: handle_nullreq: "
+                                   "input handle has unexpected length %d\n",
+                                   in_handle.length);
+                       goto out_err;
+               }
+               /* in_handle is the context id stored in the out_handle
+                * for the GSS_S_CONTINUE_NEEDED case below.  */
+               memcpy(&ctx, in_handle.value, in_handle.length);
        }
 
        maj_stat = gss_accept_sec_context(&min_stat, &ctx, gssd_creds,
                        &in_tok, GSS_C_NO_CHANNEL_BINDINGS, &client_name,
                        &mech, &out_tok, &ret_flags, NULL, NULL);
-       if (maj_stat != GSS_S_COMPLETE) {
+
+       if (maj_stat == GSS_S_CONTINUE_NEEDED) {
+               printerr(1, "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
+
+               /* Save the context handle for future calls */
+               out_handle.length = sizeof(ctx);
+               memcpy(out_handle.value, &ctx, sizeof(ctx));
+               goto continue_needed;
+       }
+       else if (maj_stat != GSS_S_COMPLETE) {
                printerr(0, "WARNING: gss_accept_sec_context failed\n");
                pgsserr("handle_nullreq: gss_accept_sec_context",
                        maj_stat, min_stat, mech);
                goto out_err;
        }
        if (get_ids(client_name, mech, &cred)) {
-               printerr(0, "WARNING: handle_nullreq: get_uid failed\n");
+               /* get_ids() prints error msg */
                maj_stat = GSS_S_BAD_NAME; /* XXX ? */
                goto out_err;
        }
 
+
+       /* Context complete. Pass handle_seq in out_handle to use
+        * for context lookup in the kernel. */
+       handle_seq++;
+       out_handle.length = sizeof(handle_seq);
+       memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
+
        /* kernel needs ctx to calculate verifier on null response, so
         * must give it context before doing null call: */
        if (serialize_context_for_kernel(ctx, &ctx_token)) {
@@ -341,15 +372,17 @@ handle_nullreq(FILE *f) {
                goto out_err;
        }
        do_svc_downcall(&out_handle, &cred, mech, &ctx_token);
+continue_needed:
        send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
                        &out_handle, &out_tok);
-       goto out;
-out_err:
-       send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
-                       &null_token, &null_token);
 out:
        if (ctx_token.value != NULL)
                free(ctx_token.value);
        printerr(1, "finished handling null request\n");
        return;
+
+out_err:
+       send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
+                       &null_token, &null_token);
+       goto out;
 }