]> git.decadent.org.uk Git - nfs-utils.git/blobdiff - utils/svcgssd/svcgssd_proc.c
Updates from Kevin Coffman at UMich
[nfs-utils.git] / utils / svcgssd / svcgssd_proc.c
index c2470c60f42604f6a8bf5880db1c3a520d6c5b3d..b43a023882675d6788fe758010fe4e7dddb18f95 100644 (file)
@@ -44,6 +44,7 @@
 #include <string.h>
 #include <fcntl.h>
 #include <errno.h>
+#include <nfsidmap.h>
 
 #include "svcgssd.h"
 #include "gss_util.h"
 #include "context.h"
 #include "cacheio.h"
 
-/* XXX: ? */
-#ifndef NGROUPS
-#define NGROUPS 32
-#endif
-
 extern char * mech2file(gss_OID mech);
 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.rpcsec.context/channel"
 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.rpcsec.init/channel"
 
+#define TOKEN_BUF_SIZE         8192
+
 struct svc_cred {
        uid_t   cr_uid;
        gid_t   cr_gid;
+       int     cr_ngroups;
        gid_t   cr_groups[NGROUPS];
 };
 
@@ -71,7 +70,7 @@ do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
                gss_OID mech, gss_buffer_desc *context_token)
 {
        FILE *f;
-       int i, ngroups;
+       int i;
        char *fname = NULL;
 
        printerr(1, "doing downcall\n");
@@ -89,15 +88,8 @@ do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
        qword_printint(f, 0x7fffffff); /*XXX need a better timeout */
        qword_printint(f, cred->cr_uid);
        qword_printint(f, cred->cr_gid);
-       ngroups = NGROUPS;
-       for (i=0; i < NGROUPS; i++) {
-               if (cred->cr_groups[i] == NOGROUP) {
-                       ngroups = i;
-                       break;
-               }
-       }
-       qword_printint(f, ngroups);
-       for (i=0; i < ngroups; i++)
+       qword_printint(f, cred->cr_ngroups);
+       for (i=0; i < cred->cr_ngroups; i++)
                qword_printint(f, cred->cr_groups[i]);
        qword_print(f, fname);
        qword_printhex(f, context_token->value, context_token->length);
@@ -121,7 +113,7 @@ send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
              u_int32_t maj_stat, u_int32_t min_stat,
              gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
 {
-       char buf[2 * 4096];
+       char buf[2 * TOKEN_BUF_SIZE];
        char *bp = buf;
        int blen = sizeof(buf);
        /* XXXARG: */
@@ -167,42 +159,76 @@ send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
 #define rpcsec_gsserr_credproblem      13
 #define rpcsec_gsserr_ctxproblem       14
 
-/* XXX memory leaks everywhere: */
+static void
+add_supplementary_groups(char *secname, char *name, struct svc_cred *cred)
+{
+       int ret;
+       static gid_t *groups = NULL;
+
+       cred->cr_ngroups = NGROUPS;
+       ret = nfs4_gss_princ_to_grouplist(secname, name,
+                       cred->cr_groups, &cred->cr_ngroups);
+       if (ret < 0) {
+               groups = realloc(groups, cred->cr_ngroups*sizeof(gid_t));
+               ret = nfs4_gss_princ_to_grouplist(secname, name,
+                               groups, &cred->cr_ngroups);
+               if (ret < 0)
+                       cred->cr_ngroups = 0;
+               else {
+                       if (cred->cr_ngroups > NGROUPS)
+                               cred->cr_ngroups = NGROUPS;
+                       memcpy(cred->cr_groups, groups,
+                                       cred->cr_ngroups*sizeof(gid_t));
+               }
+       }
+}
+
 static int
-get_ids(gss_name_t client_name, gss_OID *mech, struct svc_cred *cred)
+get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
 {
        u_int32_t       maj_stat, min_stat;
        gss_buffer_desc name;
        char            *sname;
        int             res = -1;
-       struct passwd   *pw = NULL;
-       gss_OID         name_type;
-       char            *c;
+       uid_t           uid, gid;
+       gss_OID         name_type = GSS_C_NO_OID;
+       char            *secname;
+       gid_t           *groups;
 
        maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
-       if (maj_stat != GSS_S_COMPLETE)
+       if (maj_stat != GSS_S_COMPLETE) {
+               pgsserr("get_ids: gss_display_name",
+                       maj_stat, min_stat, mech);
                goto out;
-       if (!(sname = calloc(name.length + 1, 1)))
+       }
+       if (!(sname = calloc(name.length + 1, 1))) {
+               printerr(0, "WARNING: get_ids: error allocating %d bytes "
+                       "for sname\n", name.length + 1);
                goto out;
+       }
        memcpy(sname, name.value, name.length);
        printerr(1, "sname = %s\n", sname);
-       /* XXX: should use same mapping as idmapd?  Or something; for now
-        * I'm just chopping off the domain. */
-       /* XXX: note that idmapd also does this!  It doesn't check the domain
-        * name. */
-       if ((c = strchr(sname, '@')) != NULL)
-               *c = '\0';
-       /* XXX? mapping unknown users (including machine creds) to nobody: */
-       if ( !(pw = getpwnam(sname)) && !(pw = getpwnam("nobody")) )
-               goto out;
-       cred->cr_uid = pw->pw_uid;
-       cred->cr_gid = pw->pw_gid;
-       /* XXX Read password file?  Use initgroups? I dunno...*/
-       cred->cr_groups[0] = NOGROUP;
+
+       res = -EINVAL;
+       if ((secname = mech2file(mech)) == NULL) {
+               printerr(0, "WARNING: get_ids: error mapping mech to "
+                       "file for name '%s'\n", sname);
+               goto out_free;
+       }
+       nfs4_init_name_mapping(NULL); /* XXX: should only do this once */
+       res = nfs4_gss_princ_to_ids(secname, sname, &uid, &gid);
+       if (res < 0) {
+               printerr(0, "WARNING: get_ids: unable to map "
+                       "name '%s' to a uid\n", sname);
+               goto out_free;
+       }
+       cred->cr_uid = uid;
+       cred->cr_gid = gid;
+       add_supplementary_groups(secname, sname, cred);
        res = 0;
+out_free:
+       free(sname);
 out:
-       if (res)
-               printerr(0, "WARNING: get_uid failed\n");
        return res;
 }
 
@@ -248,7 +274,7 @@ handle_nullreq(FILE *f) {
        /* XXX initialize to a random integer to reduce chances of unnecessary
         * invalidation of existing ctx's on restarting svcgssd. */
        static u_int32_t        handle_seq = 0;
-       char                    in_tok_buf[1023];
+       char                    in_tok_buf[TOKEN_BUF_SIZE];
        char                    in_handle_buf[15];
        char                    out_handle_buf[15];
        gss_buffer_desc         in_tok = {.value = in_tok_buf},
@@ -261,7 +287,7 @@ handle_nullreq(FILE *f) {
        u_int32_t               ret_flags;
        gss_ctx_id_t            ctx = GSS_C_NO_CONTEXT;
        gss_name_t              client_name;
-       gss_OID                 mech;
+       gss_OID                 mech = GSS_C_NO_OID;
        u_int32_t               maj_stat = GSS_S_FAILURE, min_stat = 0;
        struct svc_cred         cred;
        static char             *lbuf = NULL;
@@ -278,15 +304,13 @@ handle_nullreq(FILE *f) {
 
        cp = lbuf;
 
-       in_handle.length
-               = qword_get(&cp, in_handle.value, sizeof(in_handle_buf));
+       in_handle.length = (size_t) qword_get(&cp, in_handle.value,
+                                             sizeof(in_handle_buf));
        printerr(2, "in_handle: \n");
        print_hexl(2, in_handle.value, in_handle.length);
-       handle_seq++;
-       out_handle.length = sizeof(handle_seq);
-       memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
 
-       in_tok.length = qword_get(&cp, in_tok.value, sizeof(in_tok_buf));
+       in_tok.length = (size_t) qword_get(&cp, in_tok.value,
+                                          sizeof(in_tok_buf));
        printerr(2, "in_tok: \n");
        print_hexl(2, in_tok.value, in_tok.length);
 
@@ -297,48 +321,68 @@ handle_nullreq(FILE *f) {
        }
 
        if (in_handle.length != 0) { /* CONTINUE_INIT case */
-               printerr(0, "WARNING: handle_nullreq: "
-                           "CONTINUE_INIT unsupported\n");
-               send_response(f, &in_handle, &in_tok, -1, -1, &null_token,
-                               &null_token);
-               goto out_err;
+               if (in_handle.length != sizeof(ctx)) {
+                       printerr(0, "WARNING: handle_nullreq: "
+                                   "input handle has unexpected length %d\n",
+                                   in_handle.length);
+                       goto out_err;
+               }
+               /* in_handle is the context id stored in the out_handle
+                * for the GSS_S_CONTINUE_NEEDED case below.  */
+               memcpy(&ctx, in_handle.value, in_handle.length);
        }
 
        maj_stat = gss_accept_sec_context(&min_stat, &ctx, gssd_creds,
                        &in_tok, GSS_C_NO_CHANNEL_BINDINGS, &client_name,
                        &mech, &out_tok, &ret_flags, NULL, NULL);
-       if (maj_stat != GSS_S_COMPLETE) {
+
+       if (maj_stat == GSS_S_CONTINUE_NEEDED) {
+               printerr(1, "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
+
+               /* Save the context handle for future calls */
+               out_handle.length = sizeof(ctx);
+               memcpy(out_handle.value, &ctx, sizeof(ctx));
+               goto continue_needed;
+       }
+       else if (maj_stat != GSS_S_COMPLETE) {
                printerr(0, "WARNING: gss_accept_sec_context failed\n");
                pgsserr("handle_nullreq: gss_accept_sec_context",
                        maj_stat, min_stat, mech);
-               send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
-                               &null_token, &null_token);
                goto out_err;
        }
-       if (get_ids(client_name, &mech, &cred)) {
-               printerr(0, "WARNING: handle_nullreq: get_uid failed\n");
-               send_response(f, &in_handle, &in_tok, GSS_S_BAD_NAME /* XXX? */,
-                               0, &null_token, &null_token);
+       if (get_ids(client_name, mech, &cred)) {
+               /* get_ids() prints error msg */
+               maj_stat = GSS_S_BAD_NAME; /* XXX ? */
                goto out_err;
        }
 
+
+       /* Context complete. Pass handle_seq in out_handle to use
+        * for context lookup in the kernel. */
+       handle_seq++;
+       out_handle.length = sizeof(handle_seq);
+       memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
+
        /* kernel needs ctx to calculate verifier on null response, so
         * must give it context before doing null call: */
        if (serialize_context_for_kernel(ctx, &ctx_token)) {
                printerr(0, "WARNING: handle_nullreq: "
                            "serialize_context_for_kernel failed\n");
-               send_response(f, &in_handle, &in_tok, -1, /* XXX? */
-                               0, &null_token, &null_token);
+               maj_stat = GSS_S_FAILURE;
                goto out_err;
        }
        do_svc_downcall(&out_handle, &cred, mech, &ctx_token);
+continue_needed:
        send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
                        &out_handle, &out_tok);
-       goto out;
-out_err:
 out:
        if (ctx_token.value != NULL)
                free(ctx_token.value);
        printerr(1, "finished handling null request\n");
        return;
+
+out_err:
+       send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
+                       &null_token, &null_token);
+       goto out;
 }