]> git.decadent.org.uk Git - nfs-utils.git/blobdiff - utils/mount/network.c
mount.nfs: Always close the socket at the end of getport()
[nfs-utils.git] / utils / mount / network.c
index c092571044b0c31d61cec3f7862c16a93c8f8650..132ff1e763fc4980e6498427f29bc6f5a26cedfb 100644 (file)
@@ -34,7 +34,6 @@
 #include <rpc/pmap_clnt.h>
 #include <sys/socket.h>
 
-#include "conn.h"
 #include "xcommon.h"
 #include "mount.h"
 #include "nls.h"
 #define NFS_PORT 2049
 #endif
 
+#define PMAP_TIMEOUT   (10)
+#define CONNECT_TIMEOUT        (20)
+#define MOUNT_TIMEOUT  (30)
+
+#if SIZEOF_SOCKLEN_T - 0 == 0
+#define socklen_t unsigned int
+#endif
+
 extern int nfs_mount_data_version;
 extern char *progname;
 extern int verbose;
@@ -154,6 +161,134 @@ int nfs_gethostbyname(const char *hostname, struct sockaddr_in *saddr)
        return 1;
 }
 
+/*
+ * Attempt to connect a socket, but time out after "timeout" seconds.
+ *
+ * On error return, caller closes the socket.
+ */
+static int connect_to(int fd, struct sockaddr *addr,
+                       socklen_t addrlen, int timeout)
+{
+       int ret, saved;
+       fd_set rset, wset;
+       struct timeval tv = {
+               .tv_sec = timeout,
+       };
+
+       saved = fcntl(fd, F_GETFL, 0);
+       fcntl(fd, F_SETFL, saved | O_NONBLOCK);
+
+       ret = connect(fd, addr, addrlen);
+       if (ret < 0 && errno != EINPROGRESS)
+               return -1;
+       if (ret == 0)
+               goto out;
+
+       FD_ZERO(&rset);
+       FD_SET(fd, &rset);
+       wset = rset;
+       ret = select(fd + 1, &rset, &wset, NULL, &tv);
+       if (ret == 0) {
+               errno = ETIMEDOUT;
+               return -1;
+       }
+       if (FD_ISSET(fd, &rset) || FD_ISSET(fd, &wset)) {
+               int error;
+               socklen_t len = sizeof(error);
+               if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &len) < 0)
+                       return -1;
+               if (error) {
+                       errno = error;
+                       return -1;
+               }
+       } else
+               return -1;
+
+out:
+       fcntl(fd, F_SETFL, saved);
+       return 0;
+}
+
+/*
+ * Create a socket that is locally bound to a reserved or non-reserved port.
+ *
+ * The caller should check rpc_createerr to determine the cause of any error.
+ */
+static int get_socket(struct sockaddr_in *saddr, unsigned int p_prot,
+                       unsigned int timeout, int resvp, int conn)
+{
+       int so, cc, type;
+       struct sockaddr_in laddr;
+       socklen_t namelen = sizeof(laddr);
+
+       type = (p_prot == IPPROTO_UDP ? SOCK_DGRAM : SOCK_STREAM);
+       if ((so = socket (AF_INET, type, p_prot)) < 0)
+               goto err_socket;
+
+       laddr.sin_family = AF_INET;
+       laddr.sin_port = 0;
+       laddr.sin_addr.s_addr = htonl(INADDR_ANY);
+       if (resvp) {
+               if (bindresvport(so, &laddr) < 0)
+                       goto err_bindresvport;
+       } else {
+               cc = bind(so, (struct sockaddr *)&laddr, namelen);
+               if (cc < 0)
+                       goto err_bind;
+       }
+       if (type == SOCK_STREAM || (conn && type == SOCK_DGRAM)) {
+               cc = connect_to(so, (struct sockaddr *)saddr, namelen,
+                               timeout);
+               if (cc < 0)
+                       goto err_connect;
+       }
+       return so;
+
+err_socket:
+       rpc_createerr.cf_stat = RPC_SYSTEMERROR;
+       rpc_createerr.cf_error.re_errno = errno;
+       if (verbose) {
+               nfs_error(_("%s: Unable to create %s socket: errno %d (%s)\n"),
+                       progname, p_prot == IPPROTO_UDP ? _("UDP") : _("TCP"),
+                       errno, strerror(errno));
+       }
+       return RPC_ANYSOCK;
+
+err_bindresvport:
+       rpc_createerr.cf_stat = RPC_SYSTEMERROR;
+       rpc_createerr.cf_error.re_errno = errno;
+       if (verbose) {
+               nfs_error(_("%s: Unable to bindresvport %s socket: errno %d"
+                               " (%s)\n"),
+                       progname, p_prot == IPPROTO_UDP ? _("UDP") : _("TCP"),
+                       errno, strerror(errno));
+       }
+       close(so);
+       return RPC_ANYSOCK;
+
+err_bind:
+       rpc_createerr.cf_stat = RPC_SYSTEMERROR;
+       rpc_createerr.cf_error.re_errno = errno;
+       if (verbose) {
+               nfs_error(_("%s: Unable to bind to %s socket: errno %d (%s)\n"),
+                       progname, p_prot == IPPROTO_UDP ? _("UDP") : _("TCP"),
+                       errno, strerror(errno));
+       }
+       close(so);
+       return RPC_ANYSOCK;
+
+err_connect:
+       rpc_createerr.cf_stat = RPC_SYSTEMERROR;
+       rpc_createerr.cf_error.re_errno = errno;
+       if (verbose) {
+               nfs_error(_("%s: Unable to connect to %s:%d, errno %d (%s)\n"),
+                       progname, inet_ntoa(saddr->sin_addr),
+                       ntohs(saddr->sin_port), errno, strerror(errno));
+       }
+       close(so);
+       return RPC_ANYSOCK;
+}
+
 /*
  * getport() is very similar to pmap_getport() with the exception that
  * this version tries to use an ephemeral port, since reserved ports are
@@ -180,7 +315,7 @@ static unsigned short getport(struct sockaddr_in *saddr,
         * clnt*create() will create one anyway if this
         * fails.
         */
-       socket = get_socket(saddr, proto, FALSE, FALSE);
+       socket = get_socket(saddr, proto, PMAP_TIMEOUT, FALSE, FALSE);
        if (socket == RPC_ANYSOCK) {
                if (proto == IPPROTO_TCP && errno == ETIMEDOUT) {
                        /*
@@ -225,8 +360,7 @@ static unsigned short getport(struct sockaddr_in *saddr,
                else if (port == 0)
                        rpc_createerr.cf_stat = RPC_PROGNOTREGISTERED;
        }
-       if (socket != 1)
-               close(socket);
+       close(socket);
 
        return port;
 }
@@ -251,7 +385,6 @@ static int probe_port(clnt_addr_t *server, const unsigned long *versions,
        p_vers = vers ? &vers : versions;
        rpc_createerr.cf_stat = 0;
        for (;;) {
-               saddr->sin_port = htons(PMAPPORT);
                p_port = getport(saddr, prog, *p_vers, *p_prot);
                if (p_port) {
                        if (!port || port == p_port) {
@@ -471,7 +604,8 @@ CLIENT *mnt_openclnt(clnt_addr_t *mnt_server, int *msock)
        CLIENT *clnt = NULL;
 
        mnt_saddr->sin_port = htons((u_short)mnt_pmap->pm_port);
-       *msock = get_socket(mnt_saddr, mnt_pmap->pm_prot, TRUE, FALSE);
+       *msock = get_socket(mnt_saddr, mnt_pmap->pm_prot, MOUNT_TIMEOUT,
+                               TRUE, FALSE);
        if (*msock == RPC_ANYSOCK) {
                if (rpc_createerr.cf_error.re_errno == EADDRINUSE)
                        /*
@@ -510,3 +644,78 @@ void mnt_closeclnt(CLIENT *clnt, int msock)
        clnt_destroy(clnt);
        close(msock);
 }
+
+/*
+ * Sigh... getport() doesn't actually check the version number.
+ * In order to make sure that the server actually supports the service
+ * we're requesting, we open and RPC client, and fire off a NULL
+ * RPC call.
+ */
+int clnt_ping(struct sockaddr_in *saddr, const unsigned long prog,
+               const unsigned long vers, const unsigned int prot,
+               struct sockaddr_in *caddr)
+{
+       CLIENT *clnt = NULL;
+       int sock, stat;
+       static char clnt_res;
+       struct sockaddr dissolve;
+
+       rpc_createerr.cf_stat = stat = errno = 0;
+       sock = get_socket(saddr, prot, CONNECT_TIMEOUT, FALSE, TRUE);
+       if (sock == RPC_ANYSOCK) {
+               if (errno == ETIMEDOUT) {
+                       /*
+                        * TCP timeout. Bubble up the error to see 
+                        * how it should be handled.
+                        */
+                       rpc_createerr.cf_stat = RPC_TIMEDOUT;
+               }
+               return 0;
+       }
+
+       if (caddr) {
+               /* Get the address of our end of this connection */
+               socklen_t len = sizeof(*caddr);
+               if (getsockname(sock, caddr, &len) != 0)
+                       caddr->sin_family = 0;
+       }
+
+       switch(prot) {
+       case IPPROTO_UDP:
+               /* The socket is connected (so we could getsockname successfully),
+                * but some servers on multi-homed hosts reply from
+                * the wrong address, so if we stay connected, we lose the reply.
+                */
+               dissolve.sa_family = AF_UNSPEC;
+               connect(sock, &dissolve, sizeof(dissolve));
+
+               clnt = clntudp_bufcreate(saddr, prog, vers,
+                                        RETRY_TIMEOUT, &sock,
+                                        RPCSMALLMSGSIZE, RPCSMALLMSGSIZE);
+               break;
+       case IPPROTO_TCP:
+               clnt = clnttcp_create(saddr, prog, vers, &sock,
+                                     RPCSMALLMSGSIZE, RPCSMALLMSGSIZE);
+               break;
+       }
+       if (!clnt) {
+               close(sock);
+               return 0;
+       }
+       memset(&clnt_res, 0, sizeof(clnt_res));
+       stat = clnt_call(clnt, NULLPROC,
+                        (xdrproc_t)xdr_void, (caddr_t)NULL,
+                        (xdrproc_t)xdr_void, (caddr_t)&clnt_res,
+                        TIMEOUT);
+       if (stat) {
+               clnt_geterr(clnt, &rpc_createerr.cf_error);
+               rpc_createerr.cf_stat = stat;
+       }
+       clnt_destroy(clnt);
+       close(sock);
+
+       if (stat == RPC_SUCCESS)
+               return 1;
+       else
+               return 0;
+}