]> git.decadent.org.uk Git - nfs-utils.git/blobdiff - utils/gssd/gssd_proc.c
nfs-utils: switch gssd to use standard function for getting an RPC client
[nfs-utils.git] / utils / gssd / gssd_proc.c
index 416653be28144238b858f6049b3a51157a31df4c..1d55d8d1a58267080c4b160d4b2a763348b3f7c1 100644 (file)
 
 */
 
+#ifdef HAVE_CONFIG_H
+#include <config.h>
+#endif /* HAVE_CONFIG_H */
+
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
 #endif
-#include "config.h"
+
 #include <sys/param.h>
 #include <rpc/rpc.h>
 #include <sys/stat.h>
 #include <sys/socket.h>
 #include <arpa/inet.h>
+#include <sys/fsuid.h>
 
 #include <stdio.h>
 #include <stdlib.h>
@@ -65,9 +70,9 @@
 #include "gssd.h"
 #include "err_util.h"
 #include "gss_util.h"
-#include "gss_oids.h"
 #include "krb5_util.h"
 #include "context.h"
+#include "nfsrpc.h"
 
 /*
  * pollarray:
  *     with an index into pollarray[], and other basic data about that client.
  *
  * Directory structure: created by the kernel nfs client
- *      /pipefsdir/clntXX             : one per rpc_clnt struct in the kernel
- *      /pipefsdir/clntXX/krb5        : read uid for which kernel wants
- *                                      a context, write the resulting context
- *      /pipefsdir/clntXX/info        : stores info such as server name
+ *      {pipefs_nfsdir}/clntXX             : one per rpc_clnt struct in the kernel
+ *      {pipefs_nfsdir}/clntXX/krb5        : read uid for which kernel wants
+ *                                         a context, write the resulting context
+ *      {pipefs_nfsdir}/clntXX/info        : stores info such as server name
  *
  * Algorithm:
- *      Poll all /pipefsdir/clntXX/krb5 files.  When ready, data read
+ *      Poll all {pipefs_nfsdir}/clntXX/krb5 files.  When ready, data read
  *      is a uid; performs rpcsec_gss context initialization protocol to
  *      get a cred for that user.  Writes result to corresponding krb5 file
  *      in a form the kernel code will understand.
  *      In addition, we make sure we are notified whenever anything is
- *      created or destroyed in pipefsdir/ or in an of the clntXX directories,
- *      and rescan the whole pipefsdir when this happens.
+ *      created or destroyed in {pipefs_nfsdir} or in an of the clntXX directories,
+ *      and rescan the whole {pipefs_nfsdir} when this happens.
  */
 
 struct pollfd * pollarray;
 
 int pollsize;  /* the size of pollaray (in pollfd's) */
 
+/*
+ * convert a presentation address string to a sockaddr_storage struct. Returns
+ * true on success and false on failure.
+ */
+static int
+addrstr_to_sockaddr(struct sockaddr *sa, const char *addr, const int port)
+{
+       struct sockaddr_in      *s4 = (struct sockaddr_in *) sa;
+
+       if (inet_pton(AF_INET, addr, &s4->sin_addr)) {
+               s4->sin_family = AF_INET;
+               s4->sin_port = htons(port);
+       } else {
+               printerr(0, "ERROR: unable to convert %s to address\n", addr);
+               return 0;
+       }
+
+       return 1;
+}
+
+/*
+ * convert a sockaddr to a hostname
+ */
+static char *
+sockaddr_to_hostname(const struct sockaddr *sa, const char *addr)
+{
+       socklen_t               addrlen;
+       int                     err;
+       char                    *hostname;
+       char                    hbuf[NI_MAXHOST];
+
+       switch (sa->sa_family) {
+       case AF_INET:
+               addrlen = sizeof(struct sockaddr_in);
+               break;
+       default:
+               printerr(0, "ERROR: unrecognized addr family %d\n",
+                        sa->sa_family);
+               return NULL;
+       }
+
+       err = getnameinfo(sa, addrlen, hbuf, sizeof(hbuf), NULL, 0,
+                         NI_NAMEREQD);
+       if (err) {
+               printerr(0, "ERROR: unable to resolve %s to hostname: %s\n",
+                        addr, err == EAI_SYSTEM ? strerror(err) :
+                                                  gai_strerror(err));
+               return NULL;
+       }
+
+       hostname = strdup(hbuf);
+
+       return hostname;
+}
+
 /* XXX buffer problems: */
 static int
 read_service_info(char *info_file_name, char **servicename, char **servername,
-                 int *prog, int *vers, char **protocol) {
+                 int *prog, int *vers, char **protocol,
+                 struct sockaddr *addr) {
 #define INFOBUFLEN 256
-       char            buf[INFOBUFLEN];
+       char            buf[INFOBUFLEN + 1];
        static char     dummy[128];
        int             nbytes;
        static char     service[128];
@@ -111,10 +172,11 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
        char            program[16];
        char            version[16];
        char            protoname[16];
-       in_addr_t       inaddr;
+       char            cb_port[128];
+       char            *p;
        int             fd = -1;
-       struct hostent  *ent = NULL;
        int             numfields;
+       int             port = 0;
 
        *servicename = *servername = *protocol = NULL;
 
@@ -126,11 +188,12 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
        if ((nbytes = read(fd, buf, INFOBUFLEN)) == -1)
                goto fail;
        close(fd);
+       buf[nbytes] = '\0';
 
-       numfields = sscanf(buf,"RPC server: %s\n"
-                  "service: %s %s version %s\n"
-                  "address: %s\n"
-                  "protocol: %s\n",
+       numfields = sscanf(buf,"RPC server: %127s\n"
+                  "service: %127s %15s version %15s\n"
+                  "address: %127s\n"
+                  "protocol: %15s\n",
                   dummy,
                   service, program, version,
                   address,
@@ -142,6 +205,10 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
                goto fail;
        }
 
+       cb_port[0] = '\0';
+       if ((p = strstr(buf, "port")) != NULL)
+               sscanf(p, "port: %127s\n", cb_port);
+
        /* check service, program, and version */
        if(memcmp(service, "nfs", 3)) return -1;
        *prog = atoi(program + 1); /* skip open paren */
@@ -149,16 +216,23 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
        if((*prog != 100003) || ((*vers != 2) && (*vers != 3) && (*vers != 4)))
                goto fail;
 
-       /* create service name */
-       inaddr = inet_addr(address);
-       if (!(ent = gethostbyaddr(&inaddr, sizeof(inaddr), AF_INET))) {
-               printerr(0, "ERROR: can't resolve server %s name\n", address);
-               goto fail;
+       if (cb_port[0] != '\0') {
+               port = atoi(cb_port);
+               if (port < 0 || port > 65535)
+                       goto fail;
        }
-       if (!(*servername = calloc(strlen(ent->h_name) + 1, 1)))
+
+       if (!addrstr_to_sockaddr(addr, address, port))
+               goto fail;
+
+       *servername = sockaddr_to_hostname(addr, address);
+       if (*servername == NULL)
                goto fail;
-       memcpy(*servername, ent->h_name, strlen(ent->h_name));
-       snprintf(buf, INFOBUFLEN, "%s@%s", service, ent->h_name);
+
+       nbytes = snprintf(buf, INFOBUFLEN, "%s@%s", service, *servername);
+       if (nbytes > INFOBUFLEN)
+               goto fail;
+
        if (!(*servicename = calloc(strlen(buf) + 1, 1)))
                goto fail;
        memcpy(*servicename, buf, strlen(buf));
@@ -169,22 +243,29 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
 fail:
        printerr(0, "ERROR: failed to read service info\n");
        if (fd != -1) close(fd);
-       if (*servername) free(*servername);
-       if (*servicename) free(*servicename);
-       if (*protocol) free(*protocol);
+       free(*servername);
+       free(*servicename);
+       free(*protocol);
+       *servicename = *servername = *protocol = NULL;
        return -1;
 }
 
 static void
 destroy_client(struct clnt_info *clp)
 {
+       if (clp->krb5_poll_index != -1)
+               memset(&pollarray[clp->krb5_poll_index], 0,
+                                       sizeof(struct pollfd));
+       if (clp->spkm3_poll_index != -1)
+               memset(&pollarray[clp->spkm3_poll_index], 0,
+                                       sizeof(struct pollfd));
        if (clp->dir_fd != -1) close(clp->dir_fd);
        if (clp->krb5_fd != -1) close(clp->krb5_fd);
        if (clp->spkm3_fd != -1) close(clp->spkm3_fd);
-       if (clp->dirname) free(clp->dirname);
-       if (clp->servicename) free(clp->servicename);
-       if (clp->servername) free(clp->servername);
-       if (clp->protocol) free(clp->protocol);
+       free(clp->dirname);
+       free(clp->servicename);
+       free(clp->servername);
+       free(clp->protocol);
        free(clp);
 }
 
@@ -216,17 +297,22 @@ process_clnt_dir_files(struct clnt_info * clp)
        char    sname[32];
        char    info_file_name[32];
 
-       snprintf(kname, sizeof(kname), "%s/krb5", clp->dirname);
-       clp->krb5_fd = open(kname, O_RDWR);
-       snprintf(sname, sizeof(sname), "%s/spkm3", clp->dirname);
-       clp->spkm3_fd = open(sname, O_RDWR);
+       if (clp->krb5_fd == -1) {
+               snprintf(kname, sizeof(kname), "%s/krb5", clp->dirname);
+               clp->krb5_fd = open(kname, O_RDWR);
+       }
+       if (clp->spkm3_fd == -1) {
+               snprintf(sname, sizeof(sname), "%s/spkm3", clp->dirname);
+               clp->spkm3_fd = open(sname, O_RDWR);
+       }
        if((clp->krb5_fd == -1) && (clp->spkm3_fd == -1))
                return -1;
        snprintf(info_file_name, sizeof(info_file_name), "%s/info",
                        clp->dirname);
-       if (read_service_info(info_file_name, &clp->servicename,
+       if ((clp->servicename == NULL) &&
+            read_service_info(info_file_name, &clp->servicename,
                                &clp->servername, &clp->prog, &clp->vers,
-                               &clp->protocol))
+                               &clp->protocol, (struct sockaddr *) &clp->addr))
                return -1;
        return 0;
 }
@@ -250,6 +336,31 @@ get_poll_index(int *ind)
        return 0;
 }
 
+
+static int
+insert_clnt_poll(struct clnt_info *clp)
+{
+       if ((clp->krb5_fd != -1) && (clp->krb5_poll_index == -1)) {
+               if (get_poll_index(&clp->krb5_poll_index)) {
+                       printerr(0, "ERROR: Too many krb5 clients\n");
+                       return -1;
+               }
+               pollarray[clp->krb5_poll_index].fd = clp->krb5_fd;
+               pollarray[clp->krb5_poll_index].events |= POLLIN;
+       }
+
+       if ((clp->spkm3_fd != -1) && (clp->spkm3_poll_index == -1)) {
+               if (get_poll_index(&clp->spkm3_poll_index)) {
+                       printerr(0, "ERROR: Too many spkm3 clients\n");
+                       return -1;
+               }
+               pollarray[clp->spkm3_poll_index].fd = clp->spkm3_fd;
+               pollarray[clp->spkm3_poll_index].events |= POLLIN;
+       }
+
+       return 0;
+}
+
 static void
 process_clnt_dir(char *dir)
 {
@@ -273,23 +384,8 @@ process_clnt_dir(char *dir)
        if (process_clnt_dir_files(clp))
                goto fail_keep_client;
 
-       if(clp->krb5_fd != -1) {
-               if (get_poll_index(&clp->krb5_poll_index)) {
-                       printerr(0, "ERROR: Too many krb5 clients\n");
-                       goto fail_destroy_client;
-               }
-               pollarray[clp->krb5_poll_index].fd = clp->krb5_fd;
-               pollarray[clp->krb5_poll_index].events |= POLLIN;
-       }
-
-       if(clp->spkm3_fd != -1) {
-               if (get_poll_index(&clp->spkm3_poll_index)) {
-                       printerr(0, "ERROR: Too many spkm3 clients\n");
-                       goto fail_destroy_client;
-               }
-               pollarray[clp->spkm3_poll_index].fd = clp->spkm3_fd;
-               pollarray[clp->spkm3_poll_index].events |= POLLIN;
-       }
+       if (insert_clnt_poll(clp))
+               goto fail_destroy_client;
 
        return;
 
@@ -314,18 +410,50 @@ init_client_list(void)
        pollarray = calloc(pollsize, sizeof(struct pollfd));
 }
 
+/*
+ * This is run after a DNOTIFY signal, and should clear up any
+ * directories that are no longer around, and re-scan any existing
+ * directories, since the DNOTIFY could have been in there.
+ */
 static void
-destroy_client_list(void)
+update_old_clients(struct dirent **namelist, int size)
 {
-       struct clnt_info        *clp;
+       struct clnt_info *clp;
+       void *saveprev;
+       int i, stillhere;
+
+       for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
+               stillhere = 0;
+               for (i=0; i < size; i++) {
+                       if (!strcmp(clp->dirname, namelist[i]->d_name)) {
+                               stillhere = 1;
+                               break;
+                       }
+               }
+               if (!stillhere) {
+                       printerr(2, "destroying client %s\n", clp->dirname);
+                       saveprev = clp->list.tqe_prev;
+                       TAILQ_REMOVE(&clnt_list, clp, list);
+                       destroy_client(clp);
+                       clp = saveprev;
+               }
+       }
+       for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
+               if (!process_clnt_dir_files(clp))
+                       insert_clnt_poll(clp);
+       }
+}
 
-       printerr(1, "processing client list\n");
+/* Search for a client by directory name, return 1 if found, 0 otherwise */
+static int
+find_client(char *dirname)
+{
+       struct clnt_info        *clp;
 
-       while (clnt_list.tqh_first != NULL) {
-               clp = clnt_list.tqh_first;
-               TAILQ_REMOVE(&clnt_list, clp, list);
-               destroy_client(clp);
-       }
+       for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next)
+               if (!strcmp(clp->dirname, dirname))
+                       return 1;
+       return 0;
 }
 
 /* Used to read (and re-read) list of clients, set up poll array. */
@@ -333,27 +461,25 @@ int
 update_client_list(void)
 {
        struct dirent **namelist;
-       int i,j;
-
-       destroy_client_list();
+       int i, j;
 
-       if (chdir(pipefsdir) < 0) {
+       if (chdir(pipefs_nfsdir) < 0) {
                printerr(0, "ERROR: can't chdir to %s: %s\n",
-                        pipefsdir, strerror(errno));
+                        pipefs_nfsdir, strerror(errno));
                return -1;
        }
 
-       memset(pollarray, 0, pollsize * sizeof(struct pollfd));
-
-       j = scandir(pipefsdir, &namelist, NULL, alphasort);
+       j = scandir(pipefs_nfsdir, &namelist, NULL, alphasort);
        if (j < 0) {
                printerr(0, "ERROR: can't scandir %s: %s\n",
-                        pipefsdir, strerror(errno));
+                        pipefs_nfsdir, strerror(errno));
                return -1;
        }
+       update_old_clients(namelist, j);
        for (i=0; i < j; i++) {
                if (i < FD_ALLOC_BLOCK
-                               && !strncmp(namelist[i]->d_name, "clnt", 4))
+                               && !strncmp(namelist[i]->d_name, "clnt", 4)
+                               && !find_client(namelist[i]->d_name))
                        process_clnt_dir(namelist[i]->d_name);
                free(namelist[i]);
        }
@@ -366,23 +492,29 @@ static int
 do_downcall(int k5_fd, uid_t uid, struct authgss_private_data *pd,
            gss_buffer_desc *context_token)
 {
-       char    buf[2048];
-       char    *p = buf, *end = buf + 2048;
-       unsigned int timeout = 0; /* XXX decide on a reasonable value */
+       char    *buf = NULL, *p = NULL, *end = NULL;
+       unsigned int timeout = context_timeout;
+       unsigned int buf_size = 0;
 
        printerr(1, "doing downcall\n");
+       buf_size = sizeof(uid) + sizeof(timeout) + sizeof(pd->pd_seq_win) +
+               sizeof(pd->pd_ctx_hndl.length) + pd->pd_ctx_hndl.length +
+               sizeof(context_token->length) + context_token->length;
+       p = buf = malloc(buf_size);
+       end = buf + buf_size;
 
        if (WRITE_BYTES(&p, end, uid)) goto out_err;
-       /* Not setting any timeout for now: */
        if (WRITE_BYTES(&p, end, timeout)) goto out_err;
        if (WRITE_BYTES(&p, end, pd->pd_seq_win)) goto out_err;
        if (write_buffer(&p, end, &pd->pd_ctx_hndl)) goto out_err;
        if (write_buffer(&p, end, context_token)) goto out_err;
 
        if (write(k5_fd, buf, p - buf) < p - buf) goto out_err;
+       if (buf) free(buf);
        return 0;
 out_err:
-       printerr(0, "Failed to write downcall!\n");
+       if (buf) free(buf);
+       printerr(1, "Failed to write downcall!\n");
        return -1;
 }
 
@@ -405,15 +537,84 @@ do_error_downcall(int k5_fd, uid_t uid, int err)
        if (write(k5_fd, buf, p - buf) < p - buf) goto out_err;
        return 0;
 out_err:
-       printerr(0, "Failed to write error downcall!\n");
+       printerr(1, "Failed to write error downcall!\n");
        return -1;
 }
 
+/*
+ * If the port isn't already set, do an rpcbind query to the remote server
+ * using the program and version and get the port. 
+ *
+ * Newer kernels send the value of the port= mount option in the "info"
+ * file for the upcall or '0' for NFSv2/3. For NFSv4 it sends the value
+ * of the port= option or '2049'. The port field in a new sockaddr should
+ * reflect the value that was sent by the kernel.
+ */
+static int
+populate_port(struct sockaddr *sa, const socklen_t salen,
+             const rpcprog_t program, const rpcvers_t version,
+             const unsigned short protocol)
+{
+       struct sockaddr_in      *s4 = (struct sockaddr_in *) sa;
+       unsigned short          port;
+
+       /*
+        * Newer kernels send the port in the upcall. If we already have
+        * the port, there's no need to look it up.
+        */
+       switch (sa->sa_family) {
+       case AF_INET:
+               if (s4->sin_port != 0) {
+                       printerr(2, "DEBUG: port already set to %d\n",
+                                ntohs(s4->sin_port));
+                       return 1;
+               }
+               break;
+       default:
+               printerr(0, "ERROR: unsupported address family %d\n",
+                           sa->sa_family);
+               return 0;
+       }
+
+       /*
+        * Newer kernels that send the port in the upcall set the value to
+        * 2049 for NFSv4 mounts when one isn't specified. The check below is
+        * only for kernels that don't send the port in the upcall. For those
+        * we either have to do an rpcbind query or set it to the standard
+        * port. Doing a query could be problematic (firewalls, etc), so take
+        * the latter approach.
+        */
+       if (program == 100003 && version == 4) {
+               port = 2049;
+               goto set_port;
+       }
+
+       port = nfs_getport(sa, salen, program, version, protocol);
+       if (!port) {
+               printerr(0, "ERROR: unable to obtain port for prog %ld "
+                           "vers %ld\n", program, version);
+               return 0;
+       }
+
+set_port:
+       printerr(2, "DEBUG: setting port to %hu for prog %lu vers %lu\n", port,
+                program, version);
+
+       switch (sa->sa_family) {
+       case AF_INET:
+               s4->sin_port = htons(port);
+               break;
+       }
+
+       return 1;
+}
+
 /*
  * Create an RPC connection and establish an authenticated
  * gss context with a server.
  */
 int create_auth_rpc_client(struct clnt_info *clp,
+                          CLIENT **clnt_return,
                           AUTH **auth_return,
                           uid_t uid,
                           int authtype)
@@ -424,6 +625,21 @@ int create_auth_rpc_client(struct clnt_info *clp,
        uid_t                   save_uid = -1;
        int                     retval = -1;
        OM_uint32               min_stat;
+       char                    rpc_errmsg[1024];
+       int                     protocol;
+       struct timeval          timeout = {5, 0};
+       struct sockaddr         *addr = (struct sockaddr *) &clp->addr;
+       socklen_t               salen;
+
+       /* Create the context as the user (not as root) */
+       save_uid = geteuid();
+       if (setfsuid(uid) != 0) {
+               printerr(0, "WARNING: Failed to setfsuid for "
+                           "user with uid %d\n", uid);
+               goto out_fail;
+       }
+       printerr(2, "creating context using fsuid %d (save_uid %d)\n",
+                       uid, save_uid);
 
        sec.qop = GSS_C_QOP_DEFAULT;
        sec.svc = RPCSEC_GSS_SVC_NONE;
@@ -435,7 +651,10 @@ int create_auth_rpc_client(struct clnt_info *clp,
        }
        else if (authtype == AUTHTYPE_SPKM3) {
                sec.mech = (gss_OID)&spkm3oid;
-               sec.req_flags = GSS_C_ANON_FLAG;
+               /* XXX sec.req_flags = GSS_C_ANON_FLAG;
+                * Need a way to switch....
+                */
+               sec.req_flags = GSS_C_MUTUAL_FLAG;
        }
        else {
                printerr(0, "ERROR: Invalid authentication type (%d) "
@@ -459,25 +678,45 @@ int create_auth_rpc_client(struct clnt_info *clp,
 #endif
        }
 
-       /* Create the context as the user (not as root) */
-       save_uid = geteuid();
-       if (seteuid(uid) != 0) {
-               printerr(0, "WARNING: Failed to seteuid for "
-                           "user with uid %d\n", uid);
-               goto out_fail;
-       }
-       printerr(2, "creating context using euid %d (save_uid %d)\n",
-                       geteuid(), save_uid);
-
        /* create an rpc connection to the nfs server */
 
        printerr(2, "creating %s client for server %s\n", clp->protocol,
                        clp->servername);
-       if ((rpc_clnt = clnt_create(clp->servername, clp->prog, clp->vers,
-                                       clp->protocol)) == NULL) {
-               printerr(0, "WARNING: can't create rpc_clnt for server "
-                           "%s for user with uid %d\n",
-                       clp->servername, uid);
+
+       if ((strcmp(clp->protocol, "tcp")) == 0) {
+               protocol = IPPROTO_TCP;
+       } else if ((strcmp(clp->protocol, "udp")) == 0) {
+               protocol = IPPROTO_UDP;
+       } else {
+               printerr(0, "WARNING: unrecognized protocol, '%s', requested "
+                        "for connection to server %s for user with uid %d\n",
+                        clp->protocol, clp->servername, uid);
+               goto out_fail;
+       }
+
+       switch (addr->sa_family) {
+       case AF_INET:
+               salen = sizeof(struct sockaddr_in);
+               break;
+       default:
+               printerr(1, "ERROR: Unknown address family %d\n",
+                        addr->sa_family);
+               goto out_fail;
+       }
+
+       if (!populate_port(addr, salen, clp->prog, clp->vers, protocol))
+               goto out_fail;
+
+       rpc_clnt = nfs_get_rpcclient(addr, salen, protocol, clp->prog,
+                                    clp->vers, &timeout);
+       if (!rpc_clnt) {
+               snprintf(rpc_errmsg, sizeof(rpc_errmsg),
+                        "WARNING: can't create %s rpc_clnt to server %s for "
+                        "user with uid %d",
+                        protocol == IPPROTO_TCP ? "tcp" : "udp",
+                        clp->servername, uid);
+               printerr(0, "%s\n",
+                        clnt_spcreateerror(rpc_errmsg));
                goto out_fail;
        }
 
@@ -485,30 +724,34 @@ int create_auth_rpc_client(struct clnt_info *clp,
        auth = authgss_create_default(rpc_clnt, clp->servicename, &sec);
        if (!auth) {
                /* Our caller should print appropriate message */
-               printerr(2, "WARNING: Failed to create krb5 context for "
+               printerr(2, "WARNING: Failed to create %s context for "
                            "user with uid %d for server %s\n",
+                       (authtype == AUTHTYPE_KRB5 ? "krb5":"spkm3"),
                         uid, clp->servername);
                goto out_fail;
        }
 
-       /* Restore euid to original value */
-       if (seteuid(save_uid) != 0) {
-               printerr(0, "WARNING: Failed to restore euid"
-                           " to uid %d\n", save_uid);
-               goto out_fail;
-       }
-       save_uid = -1;
-
        /* Success !!! */
+       rpc_clnt->cl_auth = auth;
+       *clnt_return = rpc_clnt;
        *auth_return = auth;
        retval = 0;
 
-  out_fail:
+  out:
        if (sec.cred != GSS_C_NO_CREDENTIAL)
                gss_release_cred(&min_stat, &sec.cred);
+       /* Restore euid to original value */
+       if ((save_uid != -1) && (setfsuid(save_uid) != uid)) {
+               printerr(0, "WARNING: Failed to restore fsuid"
+                           " to uid %d from %d\n", save_uid, uid);
+       }
+       return retval;
+
+  out_fail:
+       /* Only destroy here if failure.  Otherwise, caller is responsible */
        if (rpc_clnt) clnt_destroy(rpc_clnt);
 
-       return retval;
+       goto out;
 }
 
 
@@ -520,16 +763,20 @@ void
 handle_krb5_upcall(struct clnt_info *clp)
 {
        uid_t                   uid;
-       AUTH                    *auth;
+       CLIENT                  *rpc_clnt = NULL;
+       AUTH                    *auth = NULL;
        struct authgss_private_data pd;
        gss_buffer_desc         token;
        char                    **credlist = NULL;
        char                    **ccname;
+       char                    **dirname;
+       int                     create_resp = -1;
 
        printerr(1, "handling krb5 upcall\n");
 
        token.length = 0;
        token.value = NULL;
+       memset(&pd, 0, sizeof(struct authgss_private_data));
 
        if (read(clp->krb5_fd, &uid, sizeof(uid)) < sizeof(uid)) {
                printerr(0, "WARNING: failed reading uid from krb5 "
@@ -537,61 +784,70 @@ handle_krb5_upcall(struct clnt_info *clp)
                goto out;
        }
 
-       if (uid == 0) {
-               int success = 0;
-
-               /*
-                * Get a list of credential cache names and try each
-                * of them until one works or we've tried them all
-                */
-               if (gssd_get_krb5_machine_cred_list(&credlist)) {
-                       printerr(0, "WARNING: Failed to obtain machine "
-                                   "credentials for connection to "
-                                   "server %s\n", clp->servername);
-                               goto out_return_error;
-               }
-               for (ccname = credlist; ccname && *ccname; ccname++) {
-                       gssd_setup_krb5_machine_gss_ccache(*ccname);
-                       if ((create_auth_rpc_client(clp, &auth, uid,
-                                                   AUTHTYPE_KRB5)) == 0) {
-                               /* Success! */
-                               success++;
+       if (uid != 0 || (uid == 0 && root_uses_machine_creds == 0)) {
+               /* Tell krb5 gss which credentials cache to use */
+               for (dirname = ccachesearch; *dirname != NULL; dirname++) {
+                       if (gssd_setup_krb5_user_gss_ccache(uid, clp->servername, *dirname) == 0)
+                               create_resp = create_auth_rpc_client(clp, &rpc_clnt, &auth, uid,
+                                                            AUTHTYPE_KRB5);
+                       if (create_resp == 0)
                                break;
-                       }
-                       printerr(2, "WARNING: Failed to create krb5 context "
-                                   "for user with uid %d with credentials "
-                                   "cache %s for server %s\n",
-                                uid, *ccname, clp->servername);
-               }
-               gssd_free_krb5_machine_cred_list(credlist);
-               if (!success) {
-                       printerr(0, "WARNING: Failed to create krb5 context "
-                                   "for user with uid %d with any "
-                                   "credentials cache for server %s\n",
-                                uid, clp->servername);
-                       goto out_return_error;
                }
        }
-       else {
-               /* Tell krb5 gss which credentials cache to use */
-               gssd_setup_krb5_user_gss_ccache(uid, clp->servername);
-
-               if (create_auth_rpc_client(clp, &auth, uid, AUTHTYPE_KRB5)) {
-                       printerr(0, "WARNING: Failed to create krb5 context "
-                                   "for user with uid %d for server %s\n",
+       if (create_resp != 0) {
+               if (uid == 0 && root_uses_machine_creds == 1) {
+                       int success = 0;
+
+                       gssd_refresh_krb5_machine_credential(clp->servername,
+                                                            NULL);
+                       /*
+                        * Get a list of credential cache names and try each
+                        * of them until one works or we've tried them all
+                        */
+                       if (gssd_get_krb5_machine_cred_list(&credlist)) {
+                               printerr(0, "ERROR: No credentials found "
+                                        "for connection to server %s\n",
+                                        clp->servername);
+                                       goto out_return_error;
+                       }
+                       for (ccname = credlist; ccname && *ccname; ccname++) {
+                               gssd_setup_krb5_machine_gss_ccache(*ccname);
+                               if ((create_auth_rpc_client(clp, &rpc_clnt,
+                                                           &auth, uid,
+                                                           AUTHTYPE_KRB5)) == 0) {
+                                       /* Success! */
+                                       success++;
+                                       break;
+                               }
+                               printerr(2, "WARNING: Failed to create krb5 context "
+                                        "for user with uid %d with credentials "
+                                        "cache %s for server %s\n",
+                                        uid, *ccname, clp->servername);
+                       }
+                       gssd_free_krb5_machine_cred_list(credlist);
+                       if (!success) {
+                               printerr(1, "WARNING: Failed to create krb5 context "
+                                        "for user with uid %d with any "
+                                        "credentials cache for server %s\n",
+                                        uid, clp->servername);
+                               goto out_return_error;
+                       }
+               } else {
+                       printerr(1, "WARNING: Failed to create krb5 context "
+                                "for user with uid %d for server %s\n",
                                 uid, clp->servername);
                        goto out_return_error;
                }
        }
 
        if (!authgss_get_private_data(auth, &pd)) {
-               printerr(0, "WARNING: Failed to obtain authentication "
+               printerr(1, "WARNING: Failed to obtain authentication "
                            "data for user with uid %d for server %s\n",
                         uid, clp->servername);
                goto out_return_error;
        }
 
-       if (serialize_context_for_kernel(pd.pd_ctx, &token)) {
+       if (serialize_context_for_kernel(pd.pd_ctx, &token, &krb5oid, NULL)) {
                printerr(0, "WARNING: Failed to serialize krb5 context for "
                            "user with uid %d for server %s\n",
                         uid, clp->servername);
@@ -600,14 +856,22 @@ handle_krb5_upcall(struct clnt_info *clp)
 
        do_downcall(clp->krb5_fd, uid, &pd, &token);
 
+out:
        if (token.value)
                free(token.value);
-out:
+#ifndef HAVE_LIBTIRPC
+       if (pd.pd_ctx_hndl.length != 0)
+               authgss_free_private_data(&pd);
+#endif
+       if (auth)
+               AUTH_DESTROY(auth);
+       if (rpc_clnt)
+               clnt_destroy(rpc_clnt);
        return;
 
 out_return_error:
        do_error_downcall(clp->krb5_fd, uid, -1);
-       return;
+       goto out;
 }
 
 /*
@@ -618,7 +882,8 @@ void
 handle_spkm3_upcall(struct clnt_info *clp)
 {
        uid_t                   uid;
-       AUTH                    *auth;
+       CLIENT                  *rpc_clnt = NULL;
+       AUTH                    *auth = NULL;
        struct authgss_private_data pd;
        gss_buffer_desc         token;
 
@@ -633,7 +898,7 @@ handle_spkm3_upcall(struct clnt_info *clp)
                goto out;
        }
 
-       if (create_auth_rpc_client(clp, &auth, uid, AUTHTYPE_SPKM3)) {
+       if (create_auth_rpc_client(clp, &rpc_clnt, &auth, uid, AUTHTYPE_SPKM3)) {
                printerr(0, "WARNING: Failed to create spkm3 context for "
                            "user with uid %d\n", uid);
                goto out_return_error;
@@ -646,7 +911,7 @@ handle_spkm3_upcall(struct clnt_info *clp)
                goto out_return_error;
        }
 
-       if (serialize_context_for_kernel(pd.pd_ctx, &token)) {
+       if (serialize_context_for_kernel(pd.pd_ctx, &token, &spkm3oid, NULL)) {
                printerr(0, "WARNING: Failed to serialize spkm3 context for "
                            "user with uid %d for server\n",
                         uid, clp->servername);
@@ -655,12 +920,16 @@ handle_spkm3_upcall(struct clnt_info *clp)
 
        do_downcall(clp->spkm3_fd, uid, &pd, &token);
 
+out:
        if (token.value)
                free(token.value);
-out:
+       if (auth)
+               AUTH_DESTROY(auth);
+       if (rpc_clnt)
+               clnt_destroy(rpc_clnt);
        return;
 
 out_return_error:
        do_error_downcall(clp->spkm3_fd, uid, -1);
-       return;
+       goto out;
 }