]> git.decadent.org.uk Git - nfs-utils.git/blobdiff - utils/gssd/gssd_proc.c
nfs-utils: switch gssd to use standard function for getting an RPC client
[nfs-utils.git] / utils / gssd / gssd_proc.c
index 295c37dfaa5ba515ef5ac9ffd6c5a2b9219050f7..1d55d8d1a58267080c4b160d4b2a763348b3f7c1 100644 (file)
@@ -72,6 +72,7 @@
 #include "gss_util.h"
 #include "krb5_util.h"
 #include "context.h"
+#include "nfsrpc.h"
 
 /*
  * pollarray:
@@ -102,12 +103,68 @@ struct pollfd * pollarray;
 
 int pollsize;  /* the size of pollaray (in pollfd's) */
 
+/*
+ * convert a presentation address string to a sockaddr_storage struct. Returns
+ * true on success and false on failure.
+ */
+static int
+addrstr_to_sockaddr(struct sockaddr *sa, const char *addr, const int port)
+{
+       struct sockaddr_in      *s4 = (struct sockaddr_in *) sa;
+
+       if (inet_pton(AF_INET, addr, &s4->sin_addr)) {
+               s4->sin_family = AF_INET;
+               s4->sin_port = htons(port);
+       } else {
+               printerr(0, "ERROR: unable to convert %s to address\n", addr);
+               return 0;
+       }
+
+       return 1;
+}
+
+/*
+ * convert a sockaddr to a hostname
+ */
+static char *
+sockaddr_to_hostname(const struct sockaddr *sa, const char *addr)
+{
+       socklen_t               addrlen;
+       int                     err;
+       char                    *hostname;
+       char                    hbuf[NI_MAXHOST];
+
+       switch (sa->sa_family) {
+       case AF_INET:
+               addrlen = sizeof(struct sockaddr_in);
+               break;
+       default:
+               printerr(0, "ERROR: unrecognized addr family %d\n",
+                        sa->sa_family);
+               return NULL;
+       }
+
+       err = getnameinfo(sa, addrlen, hbuf, sizeof(hbuf), NULL, 0,
+                         NI_NAMEREQD);
+       if (err) {
+               printerr(0, "ERROR: unable to resolve %s to hostname: %s\n",
+                        addr, err == EAI_SYSTEM ? strerror(err) :
+                                                  gai_strerror(err));
+               return NULL;
+       }
+
+       hostname = strdup(hbuf);
+
+       return hostname;
+}
+
 /* XXX buffer problems: */
 static int
 read_service_info(char *info_file_name, char **servicename, char **servername,
-                 int *prog, int *vers, char **protocol, int *port) {
+                 int *prog, int *vers, char **protocol,
+                 struct sockaddr *addr) {
 #define INFOBUFLEN 256
-       char            buf[INFOBUFLEN];
+       char            buf[INFOBUFLEN + 1];
        static char     dummy[128];
        int             nbytes;
        static char     service[128];
@@ -117,10 +174,9 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
        char            protoname[16];
        char            cb_port[128];
        char            *p;
-       in_addr_t       inaddr;
        int             fd = -1;
-       struct hostent  *ent = NULL;
        int             numfields;
+       int             port = 0;
 
        *servicename = *servername = *protocol = NULL;
 
@@ -132,6 +188,7 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
        if ((nbytes = read(fd, buf, INFOBUFLEN)) == -1)
                goto fail;
        close(fd);
+       buf[nbytes] = '\0';
 
        numfields = sscanf(buf,"RPC server: %127s\n"
                   "service: %127s %15s version %15s\n"
@@ -159,21 +216,26 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
        if((*prog != 100003) || ((*vers != 2) && (*vers != 3) && (*vers != 4)))
                goto fail;
 
-       /* create service name */
-       inaddr = inet_addr(address);
-       if (!(ent = gethostbyaddr(&inaddr, sizeof(inaddr), AF_INET))) {
-               printerr(0, "ERROR: can't resolve server %s name\n", address);
-               goto fail;
+       if (cb_port[0] != '\0') {
+               port = atoi(cb_port);
+               if (port < 0 || port > 65535)
+                       goto fail;
        }
-       if (!(*servername = calloc(strlen(ent->h_name) + 1, 1)))
+
+       if (!addrstr_to_sockaddr(addr, address, port))
+               goto fail;
+
+       *servername = sockaddr_to_hostname(addr, address);
+       if (*servername == NULL)
                goto fail;
-       memcpy(*servername, ent->h_name, strlen(ent->h_name));
-       snprintf(buf, INFOBUFLEN, "%s@%s", service, ent->h_name);
+
+       nbytes = snprintf(buf, INFOBUFLEN, "%s@%s", service, *servername);
+       if (nbytes > INFOBUFLEN)
+               goto fail;
+
        if (!(*servicename = calloc(strlen(buf) + 1, 1)))
                goto fail;
        memcpy(*servicename, buf, strlen(buf));
-       if (cb_port[0] != '\0')
-               *port = atoi(cb_port);
 
        if (!(*protocol = strdup(protoname)))
                goto fail;
@@ -181,9 +243,10 @@ read_service_info(char *info_file_name, char **servicename, char **servername,
 fail:
        printerr(0, "ERROR: failed to read service info\n");
        if (fd != -1) close(fd);
-       if (*servername) free(*servername);
-       if (*servicename) free(*servicename);
-       if (*protocol) free(*protocol);
+       free(*servername);
+       free(*servicename);
+       free(*protocol);
+       *servicename = *servername = *protocol = NULL;
        return -1;
 }
 
@@ -199,10 +262,10 @@ destroy_client(struct clnt_info *clp)
        if (clp->dir_fd != -1) close(clp->dir_fd);
        if (clp->krb5_fd != -1) close(clp->krb5_fd);
        if (clp->spkm3_fd != -1) close(clp->spkm3_fd);
-       if (clp->dirname) free(clp->dirname);
-       if (clp->servicename) free(clp->servicename);
-       if (clp->servername) free(clp->servername);
-       if (clp->protocol) free(clp->protocol);
+       free(clp->dirname);
+       free(clp->servicename);
+       free(clp->servername);
+       free(clp->protocol);
        free(clp);
 }
 
@@ -249,7 +312,7 @@ process_clnt_dir_files(struct clnt_info * clp)
        if ((clp->servicename == NULL) &&
             read_service_info(info_file_name, &clp->servicename,
                                &clp->servername, &clp->prog, &clp->vers,
-                               &clp->protocol, &clp->port))
+                               &clp->protocol, (struct sockaddr *) &clp->addr))
                return -1;
        return 0;
 }
@@ -478,6 +541,74 @@ out_err:
        return -1;
 }
 
+/*
+ * If the port isn't already set, do an rpcbind query to the remote server
+ * using the program and version and get the port. 
+ *
+ * Newer kernels send the value of the port= mount option in the "info"
+ * file for the upcall or '0' for NFSv2/3. For NFSv4 it sends the value
+ * of the port= option or '2049'. The port field in a new sockaddr should
+ * reflect the value that was sent by the kernel.
+ */
+static int
+populate_port(struct sockaddr *sa, const socklen_t salen,
+             const rpcprog_t program, const rpcvers_t version,
+             const unsigned short protocol)
+{
+       struct sockaddr_in      *s4 = (struct sockaddr_in *) sa;
+       unsigned short          port;
+
+       /*
+        * Newer kernels send the port in the upcall. If we already have
+        * the port, there's no need to look it up.
+        */
+       switch (sa->sa_family) {
+       case AF_INET:
+               if (s4->sin_port != 0) {
+                       printerr(2, "DEBUG: port already set to %d\n",
+                                ntohs(s4->sin_port));
+                       return 1;
+               }
+               break;
+       default:
+               printerr(0, "ERROR: unsupported address family %d\n",
+                           sa->sa_family);
+               return 0;
+       }
+
+       /*
+        * Newer kernels that send the port in the upcall set the value to
+        * 2049 for NFSv4 mounts when one isn't specified. The check below is
+        * only for kernels that don't send the port in the upcall. For those
+        * we either have to do an rpcbind query or set it to the standard
+        * port. Doing a query could be problematic (firewalls, etc), so take
+        * the latter approach.
+        */
+       if (program == 100003 && version == 4) {
+               port = 2049;
+               goto set_port;
+       }
+
+       port = nfs_getport(sa, salen, program, version, protocol);
+       if (!port) {
+               printerr(0, "ERROR: unable to obtain port for prog %ld "
+                           "vers %ld\n", program, version);
+               return 0;
+       }
+
+set_port:
+       printerr(2, "DEBUG: setting port to %hu for prog %lu vers %lu\n", port,
+                program, version);
+
+       switch (sa->sa_family) {
+       case AF_INET:
+               s4->sin_port = htons(port);
+               break;
+       }
+
+       return 1;
+}
+
 /*
  * Create an RPC connection and establish an authenticated
  * gss context with a server.
@@ -493,14 +624,12 @@ int create_auth_rpc_client(struct clnt_info *clp,
        AUTH                    *auth = NULL;
        uid_t                   save_uid = -1;
        int                     retval = -1;
-       int                     errcode;
        OM_uint32               min_stat;
        char                    rpc_errmsg[1024];
-       int                     sockp = RPC_ANYSOCK;
-       int                     sendsz = 32768, recvsz = 32768;
-       struct addrinfo         ai_hints, *a = NULL;
-       char                    service[64];
-       char                    *at_sign;
+       int                     protocol;
+       struct timeval          timeout = {5, 0};
+       struct sockaddr         *addr = (struct sockaddr *) &clp->addr;
+       socklen_t               salen;
 
        /* Create the context as the user (not as root) */
        save_uid = geteuid();
@@ -554,15 +683,10 @@ int create_auth_rpc_client(struct clnt_info *clp,
        printerr(2, "creating %s client for server %s\n", clp->protocol,
                        clp->servername);
 
-       memset(&ai_hints, '\0', sizeof(ai_hints));
-       ai_hints.ai_family = PF_INET;
-       ai_hints.ai_flags |= AI_CANONNAME;
        if ((strcmp(clp->protocol, "tcp")) == 0) {
-               ai_hints.ai_socktype = SOCK_STREAM;
-               ai_hints.ai_protocol = IPPROTO_TCP;
+               protocol = IPPROTO_TCP;
        } else if ((strcmp(clp->protocol, "udp")) == 0) {
-               ai_hints.ai_socktype = SOCK_DGRAM;
-               ai_hints.ai_protocol = IPPROTO_UDP;
+               protocol = IPPROTO_UDP;
        } else {
                printerr(0, "WARNING: unrecognized protocol, '%s', requested "
                         "for connection to server %s for user with uid %d\n",
@@ -570,72 +694,31 @@ int create_auth_rpc_client(struct clnt_info *clp,
                goto out_fail;
        }
 
-       /* extract the service name from clp->servicename */
-       if ((at_sign = strchr(clp->servicename, '@')) == NULL) {
-               printerr(0, "WARNING: servicename (%s) not formatted as "
-                       "expected with service@host\n", clp->servicename);
+       switch (addr->sa_family) {
+       case AF_INET:
+               salen = sizeof(struct sockaddr_in);
+               break;
+       default:
+               printerr(1, "ERROR: Unknown address family %d\n",
+                        addr->sa_family);
                goto out_fail;
        }
-       if ((at_sign - clp->servicename) >= sizeof(service)) {
-               printerr(0, "WARNING: service portion of servicename (%s) "
-                       "is too long!\n", clp->servicename);
-               goto out_fail;
-       }
-       strncpy(service, clp->servicename, at_sign - clp->servicename);
-       service[at_sign - clp->servicename] = '\0';
 
-       errcode = getaddrinfo(clp->servername, service, &ai_hints, &a);
-       if (errcode) {
-               printerr(0, "WARNING: Error from getaddrinfo for server "
-                        "'%s': %s\n", clp->servername, gai_strerror(errcode));
+       if (!populate_port(addr, salen, clp->prog, clp->vers, protocol))
                goto out_fail;
-       }
 
-       if (a == NULL) {
-               printerr(0, "WARNING: No address information found for "
-                        "connection to server %s for user with uid %d\n",
+       rpc_clnt = nfs_get_rpcclient(addr, salen, protocol, clp->prog,
+                                    clp->vers, &timeout);
+       if (!rpc_clnt) {
+               snprintf(rpc_errmsg, sizeof(rpc_errmsg),
+                        "WARNING: can't create %s rpc_clnt to server %s for "
+                        "user with uid %d",
+                        protocol == IPPROTO_TCP ? "tcp" : "udp",
                         clp->servername, uid);
+               printerr(0, "%s\n",
+                        clnt_spcreateerror(rpc_errmsg));
                goto out_fail;
        }
-       if (clp->port)
-               ((struct sockaddr_in *)a->ai_addr)->sin_port = htons(clp->port);
-       if (a->ai_protocol == IPPROTO_TCP) {
-               if ((rpc_clnt = clnttcp_create(
-                                       (struct sockaddr_in *) a->ai_addr,
-                                       clp->prog, clp->vers, &sockp,
-                                       sendsz, recvsz)) == NULL) {
-                       snprintf(rpc_errmsg, sizeof(rpc_errmsg),
-                                "WARNING: can't create tcp rpc_clnt "
-                                "for server %s for user with uid %d",
-                                clp->servername, uid);
-                       printerr(0, "%s\n",
-                                clnt_spcreateerror(rpc_errmsg));
-                       goto out_fail;
-               }
-       } else if (a->ai_protocol == IPPROTO_UDP) {
-               const struct timeval timeout = {5, 0};
-               if ((rpc_clnt = clntudp_bufcreate(
-                                       (struct sockaddr_in *) a->ai_addr,
-                                       clp->prog, clp->vers, timeout,
-                                       &sockp, sendsz, recvsz)) == NULL) {
-                       snprintf(rpc_errmsg, sizeof(rpc_errmsg),
-                                "WARNING: can't create udp rpc_clnt "
-                                "for server %s for user with uid %d",
-                                clp->servername, uid);
-                       printerr(0, "%s\n",
-                                clnt_spcreateerror(rpc_errmsg));
-                       goto out_fail;
-               }
-       } else {
-               /* Shouldn't happen! */
-               printerr(0, "ERROR: requested protocol '%s', but "
-                        "got addrinfo with protocol %d\n",
-                        clp->protocol, a->ai_protocol);
-               goto out_fail;
-       }
-       /* We're done with this */
-       freeaddrinfo(a);
-       a = NULL;
 
        printerr(2, "creating context with server %s\n", clp->servicename);
        auth = authgss_create_default(rpc_clnt, clp->servicename, &sec);
@@ -657,7 +740,6 @@ int create_auth_rpc_client(struct clnt_info *clp,
   out:
        if (sec.cred != GSS_C_NO_CREDENTIAL)
                gss_release_cred(&min_stat, &sec.cred);
-       if (a != NULL) freeaddrinfo(a);
        /* Restore euid to original value */
        if ((save_uid != -1) && (setfsuid(save_uid) != uid)) {
                printerr(0, "WARNING: Failed to restore fsuid"