2006-04-10 kwc@citi.umich.edu
[nfs-utils.git] / utils / gssd / gssd_proc.c
index 1e7ebae..bac0520 100644 (file)
@@ -178,6 +178,12 @@ fail:
 static void
 destroy_client(struct clnt_info *clp)
 {
+       if (clp->krb5_poll_index != -1)
+               memset(&pollarray[clp->krb5_poll_index], 0,
+                                       sizeof(struct pollfd));
+       if (clp->spkm3_poll_index != -1)
+               memset(&pollarray[clp->spkm3_poll_index], 0,
+                                       sizeof(struct pollfd));
        if (clp->dir_fd != -1) close(clp->dir_fd);
        if (clp->krb5_fd != -1) close(clp->krb5_fd);
        if (clp->spkm3_fd != -1) close(clp->spkm3_fd);
@@ -216,15 +222,20 @@ process_clnt_dir_files(struct clnt_info * clp)
        char    sname[32];
        char    info_file_name[32];
 
-       snprintf(kname, sizeof(kname), "%s/krb5", clp->dirname);
-       clp->krb5_fd = open(kname, O_RDWR);
-       snprintf(sname, sizeof(sname), "%s/spkm3", clp->dirname);
-       clp->spkm3_fd = open(sname, O_RDWR);
+       if (clp->krb5_fd == -1) {
+               snprintf(kname, sizeof(kname), "%s/krb5", clp->dirname);
+               clp->krb5_fd = open(kname, O_RDWR);
+       }
+       if (clp->spkm3_fd == -1) {
+               snprintf(sname, sizeof(sname), "%s/spkm3", clp->dirname);
+               clp->spkm3_fd = open(sname, O_RDWR);
+       }
        if((clp->krb5_fd == -1) && (clp->spkm3_fd == -1))
                return -1;
        snprintf(info_file_name, sizeof(info_file_name), "%s/info",
                        clp->dirname);
-       if (read_service_info(info_file_name, &clp->servicename,
+       if ((clp->servicename == NULL) &&
+            read_service_info(info_file_name, &clp->servicename,
                                &clp->servername, &clp->prog, &clp->vers,
                                &clp->protocol))
                return -1;
@@ -250,6 +261,31 @@ get_poll_index(int *ind)
        return 0;
 }
 
+
+static int
+insert_clnt_poll(struct clnt_info *clp)
+{
+       if ((clp->krb5_fd != -1) && (clp->krb5_poll_index == -1)) {
+               if (get_poll_index(&clp->krb5_poll_index)) {
+                       printerr(0, "ERROR: Too many krb5 clients\n");
+                       return -1;
+               }
+               pollarray[clp->krb5_poll_index].fd = clp->krb5_fd;
+               pollarray[clp->krb5_poll_index].events |= POLLIN;
+       }
+
+       if ((clp->spkm3_fd != -1) && (clp->spkm3_poll_index == -1)) {
+               if (get_poll_index(&clp->spkm3_poll_index)) {
+                       printerr(0, "ERROR: Too many spkm3 clients\n");
+                       return -1;
+               }
+               pollarray[clp->spkm3_poll_index].fd = clp->spkm3_fd;
+               pollarray[clp->spkm3_poll_index].events |= POLLIN;
+       }
+
+       return 0;
+}
+
 static void
 process_clnt_dir(char *dir)
 {
@@ -273,23 +309,8 @@ process_clnt_dir(char *dir)
        if (process_clnt_dir_files(clp))
                goto fail_keep_client;
 
-       if(clp->krb5_fd != -1) {
-               if (get_poll_index(&clp->krb5_poll_index)) {
-                       printerr(0, "ERROR: Too many krb5 clients\n");
-                       goto fail_destroy_client;
-               }
-               pollarray[clp->krb5_poll_index].fd = clp->krb5_fd;
-               pollarray[clp->krb5_poll_index].events |= POLLIN;
-       }
-
-       if(clp->spkm3_fd != -1) {
-               if (get_poll_index(&clp->spkm3_poll_index)) {
-                       printerr(0, "ERROR: Too many spkm3 clients\n");
-                       goto fail_destroy_client;
-               }
-               pollarray[clp->spkm3_poll_index].fd = clp->spkm3_fd;
-               pollarray[clp->spkm3_poll_index].events |= POLLIN;
-       }
+       if (insert_clnt_poll(clp))
+               goto fail_destroy_client;
 
        return;
 
@@ -314,18 +335,50 @@ init_client_list(void)
        pollarray = calloc(pollsize, sizeof(struct pollfd));
 }
 
+/*
+ * This is run after a DNOTIFY signal, and should clear up any
+ * directories that are no longer around, and re-scan any existing
+ * directories, since the DNOTIFY could have been in there.
+ */
 static void
-destroy_client_list(void)
+update_old_clients(struct dirent **namelist, int size)
 {
-       struct clnt_info        *clp;
+       struct clnt_info *clp;
+       void *saveprev;
+       int i, stillhere;
+
+       for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
+               stillhere = 0;
+               for (i=0; i < size; i++) {
+                       if (!strcmp(clp->dirname, namelist[i]->d_name)) {
+                               stillhere = 1;
+                               break;
+                       }
+               }
+               if (!stillhere) {
+                       printerr(2, "destroying client %s\n", clp->dirname);
+                       saveprev = clp->list.tqe_prev;
+                       TAILQ_REMOVE(&clnt_list, clp, list);
+                       destroy_client(clp);
+                       clp = saveprev;
+               }
+       }
+       for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
+               if (!process_clnt_dir_files(clp))
+                       insert_clnt_poll(clp);
+       }
+}
 
-       printerr(1, "processing client list\n");
+/* Search for a client by directory name, return 1 if found, 0 otherwise */
+static int
+find_client(char *dirname)
+{
+       struct clnt_info        *clp;
 
-       while (clnt_list.tqh_first != NULL) {
-               clp = clnt_list.tqh_first;
-               TAILQ_REMOVE(&clnt_list, clp, list);
-               destroy_client(clp);
-       }
+       for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next)
+               if (!strcmp(clp->dirname, dirname))
+                       return 1;
+       return 0;
 }
 
 /* Used to read (and re-read) list of clients, set up poll array. */
@@ -333,9 +386,7 @@ int
 update_client_list(void)
 {
        struct dirent **namelist;
-       int i,j;
-
-       destroy_client_list();
+       int i, j;
 
        if (chdir(pipefsdir) < 0) {
                printerr(0, "ERROR: can't chdir to %s: %s\n",
@@ -343,17 +394,17 @@ update_client_list(void)
                return -1;
        }
 
-       memset(pollarray, 0, pollsize * sizeof(struct pollfd));
-
        j = scandir(pipefsdir, &namelist, NULL, alphasort);
        if (j < 0) {
                printerr(0, "ERROR: can't scandir %s: %s\n",
                         pipefsdir, strerror(errno));
                return -1;
        }
+       update_old_clients(namelist, j);
        for (i=0; i < j; i++) {
                if (i < FD_ALLOC_BLOCK
-                               && !strncmp(namelist[i]->d_name, "clnt", 4))
+                               && !strncmp(namelist[i]->d_name, "clnt", 4)
+                               && !find_client(namelist[i]->d_name))
                        process_clnt_dir(namelist[i]->d_name);
                free(namelist[i]);
        }
@@ -366,21 +417,31 @@ static int
 do_downcall(int k5_fd, uid_t uid, struct authgss_private_data *pd,
            gss_buffer_desc *context_token)
 {
-       char    buf[2048];
-       char    *p = buf, *end = buf + 2048;
+       char    *buf = NULL, *p = NULL, *end = NULL;
        unsigned int timeout = 0; /* XXX decide on a reasonable value */
+       unsigned int buf_size = 0;
 
        printerr(1, "doing downcall\n");
+       buf_size = sizeof(uid) + sizeof(timeout) + sizeof(pd->pd_seq_win) +
+               sizeof(pd->pd_ctx_hndl.length) + pd->pd_ctx_hndl.length +
+               sizeof(context_token->length) + context_token->length;
+       p = buf = malloc(buf_size);
+       end = buf + buf_size;
 
-       if (WRITE_BYTES(&p, end, uid)) return -1;
+       if (WRITE_BYTES(&p, end, uid)) goto out_err;
        /* Not setting any timeout for now: */
-       if (WRITE_BYTES(&p, end, timeout)) return -1;
-       if (WRITE_BYTES(&p, end, pd->pd_seq_win)) return -1;
-       if (write_buffer(&p, end, &pd->pd_ctx_hndl)) return -1;
-       if (write_buffer(&p, end, context_token)) return -1;
+       if (WRITE_BYTES(&p, end, timeout)) goto out_err;
+       if (WRITE_BYTES(&p, end, pd->pd_seq_win)) goto out_err;
+       if (write_buffer(&p, end, &pd->pd_ctx_hndl)) goto out_err;
+       if (write_buffer(&p, end, context_token)) goto out_err;
 
-       if (write(k5_fd, buf, p - buf) < p - buf) return -1;
+       if (write(k5_fd, buf, p - buf) < p - buf) goto out_err;
+       if (buf) free(buf);
        return 0;
+out_err:
+       if (buf) free(buf);
+       printerr(0, "Failed to write downcall!\n");
+       return -1;
 }
 
 static int
@@ -393,15 +454,17 @@ do_error_downcall(int k5_fd, uid_t uid, int err)
 
        printerr(1, "doing error downcall\n");
 
-       if (WRITE_BYTES(&p, end, uid)) return -1;
-       if (WRITE_BYTES(&p, end, timeout)) return -1;
+       if (WRITE_BYTES(&p, end, uid)) goto out_err;
+       if (WRITE_BYTES(&p, end, timeout)) goto out_err;
        /* use seq_win = 0 to indicate an error: */
-       if (WRITE_BYTES(&p, end, zero)) return -1;
-       if (WRITE_BYTES(&p, end, err)) return -1;
+       if (WRITE_BYTES(&p, end, zero)) goto out_err;
+       if (WRITE_BYTES(&p, end, err)) goto out_err;
 
-       if (write(k5_fd, buf, p - buf) < p - buf) return -1;
+       if (write(k5_fd, buf, p - buf) < p - buf) goto out_err;
        return 0;
-
+out_err:
+       printerr(0, "Failed to write error downcall!\n");
+       return -1;
 }
 
 /*
@@ -409,6 +472,7 @@ do_error_downcall(int k5_fd, uid_t uid, int err)
  * gss context with a server.
  */
 int create_auth_rpc_client(struct clnt_info *clp,
+                          CLIENT **clnt_return,
                           AUTH **auth_return,
                           uid_t uid,
                           int authtype)
@@ -418,7 +482,24 @@ int create_auth_rpc_client(struct clnt_info *clp,
        AUTH                    *auth = NULL;
        uid_t                   save_uid = -1;
        int                     retval = -1;
+       int                     errcode;
        OM_uint32               min_stat;
+       char                    rpc_errmsg[1024];
+       int                     sockp = RPC_ANYSOCK;
+       int                     sendsz = 32768, recvsz = 32768;
+       struct addrinfo         ai_hints, *a = NULL;
+       char                    service[64];
+       char                    *at_sign;
+
+       /* Create the context as the user (not as root) */
+       save_uid = geteuid();
+       if (seteuid(uid) != 0) {
+               printerr(0, "WARNING: Failed to seteuid for "
+                           "user with uid %d\n", uid);
+               goto out_fail;
+       }
+       printerr(2, "creating context using euid %d (save_uid %d)\n",
+                       geteuid(), save_uid);
 
        sec.qop = GSS_C_QOP_DEFAULT;
        sec.svc = RPCSEC_GSS_SVC_NONE;
@@ -430,7 +511,10 @@ int create_auth_rpc_client(struct clnt_info *clp,
        }
        else if (authtype == AUTHTYPE_SPKM3) {
                sec.mech = (gss_OID)&spkm3oid;
-               sec.req_flags = GSS_C_ANON_FLAG;
+               /* XXX sec.req_flags = GSS_C_ANON_FLAG;
+                * Need a way to switch....
+                */
+               sec.req_flags = GSS_C_MUTUAL_FLAG;
        }
        else {
                printerr(0, "ERROR: Invalid authentication type (%d) "
@@ -454,56 +538,125 @@ int create_auth_rpc_client(struct clnt_info *clp,
 #endif
        }
 
-       /* Create the context as the user (not as root) */
-       save_uid = geteuid();
-       if (seteuid(uid) != 0) {
-               printerr(0, "WARNING: Failed to seteuid for "
-                           "user with uid %d\n", uid);
-               goto out_fail;
-       }
-       printerr(2, "creating context using euid %d (save_uid %d)\n",
-                       geteuid(), save_uid);
-
        /* create an rpc connection to the nfs server */
 
        printerr(2, "creating %s client for server %s\n", clp->protocol,
                        clp->servername);
-       if ((rpc_clnt = clnt_create(clp->servername, clp->prog, clp->vers,
-                                       clp->protocol)) == NULL) {
-               printerr(0, "WARNING: can't create rpc_clnt for server "
-                           "%s for user with uid %d\n",
-                       clp->servername, uid);
+
+       memset(&ai_hints, '\0', sizeof(ai_hints));
+       ai_hints.ai_family = PF_INET;
+       ai_hints.ai_flags |= AI_CANONNAME;
+       if ((strcmp(clp->protocol, "tcp")) == 0) {
+               ai_hints.ai_socktype = SOCK_STREAM;
+               ai_hints.ai_protocol = IPPROTO_TCP;
+       } else if ((strcmp(clp->protocol, "udp")) == 0) {
+               ai_hints.ai_socktype = SOCK_DGRAM;
+               ai_hints.ai_protocol = IPPROTO_UDP;
+       } else {
+               printerr(0, "WARNING: unrecognized protocol, '%s', requested "
+                        "for connection to server %s for user with uid %d",
+                        clp->protocol, clp->servername, uid);
+               goto out_fail;
+       }
+
+       /* extract the service name from clp->servicename */
+       if ((at_sign = strchr(clp->servicename, '@')) == NULL) {
+               printerr(0, "WARNING: servicename (%s) not formatted as "
+                       "expected with service@host", clp->servicename);
+               goto out_fail;
+       }
+       if ((at_sign - clp->servicename) >= sizeof(service)) {
+               printerr(0, "WARNING: service portion of servicename (%s) "
+                       "is too long!", clp->servicename);
                goto out_fail;
        }
+       strncpy(service, clp->servicename, at_sign - clp->servicename);
+       service[at_sign - clp->servicename] = '\0';
+
+       errcode = getaddrinfo(clp->servername, service, &ai_hints, &a);
+       if (errcode) {
+               printerr(0, "WARNING: Error from getaddrinfo for server "
+                        "'%s': %s", clp->servername, gai_strerror(errcode));
+               goto out_fail;
+       }
+
+       if (a == NULL) {
+               printerr(0, "WARNING: No address information found for "
+                        "connection to server %s for user with uid %d",
+                        clp->servername, uid);
+               goto out_fail;
+       }
+       if (a->ai_protocol == IPPROTO_TCP) {
+               if ((rpc_clnt = clnttcp_create(
+                                       (struct sockaddr_in *) a->ai_addr,
+                                       clp->prog, clp->vers, &sockp,
+                                       sendsz, recvsz)) == NULL) {
+                       snprintf(rpc_errmsg, sizeof(rpc_errmsg),
+                                "WARNING: can't create tcp rpc_clnt "
+                                "for server %s for user with uid %d",
+                                clp->servername, uid);
+                       printerr(0, "%s\n",
+                                clnt_spcreateerror(rpc_errmsg));
+                       goto out_fail;
+               }
+       } else if (a->ai_protocol == IPPROTO_UDP) {
+               const struct timeval timeout = {5, 0};
+               if ((rpc_clnt = clntudp_bufcreate(
+                                       (struct sockaddr_in *) a->ai_addr,
+                                       clp->prog, clp->vers, timeout,
+                                       &sockp, sendsz, recvsz)) == NULL) {
+                       snprintf(rpc_errmsg, sizeof(rpc_errmsg),
+                                "WARNING: can't create udp rpc_clnt "
+                                "for server %s for user with uid %d",
+                                clp->servername, uid);
+                       printerr(0, "%s\n",
+                                clnt_spcreateerror(rpc_errmsg));
+                       goto out_fail;
+               }
+       } else {
+               /* Shouldn't happen! */
+               printerr(0, "ERROR: requested protocol '%s', but "
+                        "got addrinfo with protocol %d",
+                        clp->protocol, a->ai_protocol);
+               goto out_fail;
+       }
+       /* We're done with this */
+       freeaddrinfo(a);
+       a = NULL;
 
        printerr(2, "creating context with server %s\n", clp->servicename);
        auth = authgss_create_default(rpc_clnt, clp->servicename, &sec);
        if (!auth) {
                /* Our caller should print appropriate message */
-               printerr(2, "WARNING: Failed to create krb5 context for "
+               printerr(2, "WARNING: Failed to create %s context for "
                            "user with uid %d for server %s\n",
+                       (authtype == AUTHTYPE_KRB5 ? "krb5":"spkm3"),
                         uid, clp->servername);
                goto out_fail;
        }
 
-       /* Restore euid to original value */
-       if (seteuid(save_uid) != 0) {
-               printerr(0, "WARNING: Failed to restore euid"
-                           " to uid %d\n", save_uid);
-               goto out_fail;
-       }
-       save_uid = -1;
-
        /* Success !!! */
+       rpc_clnt->cl_auth = auth;
+       *clnt_return = rpc_clnt;
        *auth_return = auth;
        retval = 0;
 
-  out_fail:
+  out:
        if (sec.cred != GSS_C_NO_CREDENTIAL)
                gss_release_cred(&min_stat, &sec.cred);
+       if (a != NULL) freeaddrinfo(a);
+       /* Restore euid to original value */
+       if ((save_uid != -1) && (seteuid(save_uid) != 0)) {
+               printerr(0, "WARNING: Failed to restore euid"
+                           " to uid %d\n", save_uid);
+       }
+       return retval;
+
+  out_fail:
+       /* Only destroy here if failure.  Otherwise, caller is responsible */
        if (rpc_clnt) clnt_destroy(rpc_clnt);
 
-       return retval;
+       goto out;
 }
 
 
@@ -515,7 +668,8 @@ void
 handle_krb5_upcall(struct clnt_info *clp)
 {
        uid_t                   uid;
-       AUTH                    *auth;
+       CLIENT                  *rpc_clnt = NULL;
+       AUTH                    *auth = NULL;
        struct authgss_private_data pd;
        gss_buffer_desc         token;
        char                    **credlist = NULL;
@@ -525,6 +679,7 @@ handle_krb5_upcall(struct clnt_info *clp)
 
        token.length = 0;
        token.value = NULL;
+       memset(&pd, 0, sizeof(struct authgss_private_data));
 
        if (read(clp->krb5_fd, &uid, sizeof(uid)) < sizeof(uid)) {
                printerr(0, "WARNING: failed reading uid from krb5 "
@@ -547,7 +702,7 @@ handle_krb5_upcall(struct clnt_info *clp)
                }
                for (ccname = credlist; ccname && *ccname; ccname++) {
                        gssd_setup_krb5_machine_gss_ccache(*ccname);
-                       if ((create_auth_rpc_client(clp, &auth, uid,
+                       if ((create_auth_rpc_client(clp, &rpc_clnt, &auth, uid,
                                                    AUTHTYPE_KRB5)) == 0) {
                                /* Success! */
                                success++;
@@ -571,7 +726,8 @@ handle_krb5_upcall(struct clnt_info *clp)
                /* Tell krb5 gss which credentials cache to use */
                gssd_setup_krb5_user_gss_ccache(uid, clp->servername);
 
-               if (create_auth_rpc_client(clp, &auth, uid, AUTHTYPE_KRB5)) {
+               if ((create_auth_rpc_client(clp, &rpc_clnt, &auth, uid,
+                                                       AUTHTYPE_KRB5)) != 0) {
                        printerr(0, "WARNING: Failed to create krb5 context "
                                    "for user with uid %d for server %s\n",
                                 uid, clp->servername);
@@ -586,7 +742,7 @@ handle_krb5_upcall(struct clnt_info *clp)
                goto out_return_error;
        }
 
-       if (serialize_context_for_kernel(pd.pd_ctx, &token)) {
+       if (serialize_context_for_kernel(pd.pd_ctx, &token, &krb5oid)) {
                printerr(0, "WARNING: Failed to serialize krb5 context for "
                            "user with uid %d for server %s\n",
                         uid, clp->servername);
@@ -595,14 +751,20 @@ handle_krb5_upcall(struct clnt_info *clp)
 
        do_downcall(clp->krb5_fd, uid, &pd, &token);
 
+out:
        if (token.value)
                free(token.value);
-out:
+       if (pd.pd_ctx_hndl.length != 0)
+               authgss_free_private_data(&pd);
+       if (auth)
+               AUTH_DESTROY(auth);
+       if (rpc_clnt)
+               clnt_destroy(rpc_clnt);
        return;
 
 out_return_error:
        do_error_downcall(clp->krb5_fd, uid, -1);
-       return;
+       goto out;
 }
 
 /*
@@ -613,7 +775,8 @@ void
 handle_spkm3_upcall(struct clnt_info *clp)
 {
        uid_t                   uid;
-       AUTH                    *auth;
+       CLIENT                  *rpc_clnt = NULL;
+       AUTH                    *auth = NULL;
        struct authgss_private_data pd;
        gss_buffer_desc         token;
 
@@ -628,7 +791,7 @@ handle_spkm3_upcall(struct clnt_info *clp)
                goto out;
        }
 
-       if (create_auth_rpc_client(clp, &auth, uid, AUTHTYPE_SPKM3)) {
+       if (create_auth_rpc_client(clp, &rpc_clnt, &auth, uid, AUTHTYPE_SPKM3)) {
                printerr(0, "WARNING: Failed to create spkm3 context for "
                            "user with uid %d\n", uid);
                goto out_return_error;
@@ -641,7 +804,7 @@ handle_spkm3_upcall(struct clnt_info *clp)
                goto out_return_error;
        }
 
-       if (serialize_context_for_kernel(pd.pd_ctx, &token)) {
+       if (serialize_context_for_kernel(pd.pd_ctx, &token, &spkm3oid)) {
                printerr(0, "WARNING: Failed to serialize spkm3 context for "
                            "user with uid %d for server\n",
                         uid, clp->servername);
@@ -650,12 +813,16 @@ handle_spkm3_upcall(struct clnt_info *clp)
 
        do_downcall(clp->spkm3_fd, uid, &pd, &token);
 
+out:
        if (token.value)
                free(token.value);
-out:
+       if (auth)
+               AUTH_DESTROY(auth);
+       if (rpc_clnt)
+               clnt_destroy(rpc_clnt);
        return;
 
 out_return_error:
        do_error_downcall(clp->spkm3_fd, uid, -1);
-       return;
+       goto out;
 }