]> git.decadent.org.uk Git - nfs-utils.git/blobdiff - utils/nfsd/nfsd.c
The wrong bit field is being passed to NFSCTL_TCPISSET()
[nfs-utils.git] / utils / nfsd / nfsd.c
index fa6ee71fdd6443245511a4b105df891017106905..aaf8d298e3ad9c589c1d0d0c31df7cd64d36da27 100644 (file)
 #include <getopt.h>
 #include <syslog.h>
 #include <netdb.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+
 #include "nfslib.h"
 
 static void    usage(const char *);
 
 static struct option longopts[] =
 {
+       { "host", 1, 0, 'H' },
        { "help", 0, 0, 'h' },
        { "no-nfs-version", 1, 0, 'N' },
+       { "no-tcp", 0, 0, 'T' },
+       { "no-udp", 0, 0, 'U' },
+       { "port", 1, 0, 'P' },
+       { "port", 1, 0, 'p' },
        { NULL, 0, 0, 0 }
 };
+unsigned int protobits = NFSCTL_ALLBITS;
 unsigned int versbits = NFSCTL_ALLBITS;
+char *haddr = NULL;
 
 int
 main(int argc, char **argv)
 {
        int     count = 1, c, error, port, fd, found_one;
        struct servent *ent;
+       struct hostent *hp;
 
        ent = getservbyname ("nfs", "udp");
        if (ent != NULL)
@@ -44,8 +56,19 @@ main(int argc, char **argv)
        else
                port = 2049;
 
-       while ((c = getopt_long(argc, argv, "hN:p:P:", longopts, NULL)) != EOF) {
+       while ((c = getopt_long(argc, argv, "H:hN:p:P:TU", longopts, NULL)) != EOF) {
                switch(c) {
+               case 'H':
+                       if (inet_addr(optarg) != INADDR_NONE) {
+                               haddr = strdup(optarg);
+                       } else if ((hp = gethostbyname(optarg)) != NULL) {
+                               haddr = inet_ntoa((*(struct in_addr*)(hp->h_addr_list[0])));
+                       } else {
+                               fprintf(stderr, "%s: Unknown hostname: %s\n",
+                                       argv[0], optarg);
+                               usage(argv [0]);
+                       }
+                       break;
                case 'P':       /* XXX for nfs-server compatibility */
                case 'p':
                        port = atoi(optarg);
@@ -67,6 +90,12 @@ main(int argc, char **argv)
                                exit(1);
                        }
                        break;
+               case 'T':
+                               NFSCTL_TCPUNSET(protobits);
+                               break;
+               case 'U':
+                               NFSCTL_UDPUNSET(protobits);
+                               break;
                default:
                        fprintf(stderr, "Invalid argument: '%c'\n", c);
                case 'h':
@@ -76,6 +105,10 @@ main(int argc, char **argv)
        /*
         * Do some sanity checking, if the ctlbits are set
         */
+       if (!NFSCTL_UDPISSET(protobits) && !NFSCTL_TCPISSET(protobits)) {
+               fprintf(stderr, "invalid protocol specified\n");
+               exit(1);
+       }
        found_one = 0;
        for (c = NFSD_MINVERS; c <= NFSD_MAXVERS; c++) {
                if (NFSCTL_VERISSET(versbits, c))
@@ -86,6 +119,15 @@ main(int argc, char **argv)
                exit(1);
        }                       
 
+       if (NFSCTL_VERISSET(versbits, 4) && !NFSCTL_TCPISSET(protobits)) {
+               fprintf(stderr, "version 4 requires the TCP protocol\n");
+               exit(1);
+       }
+       if (haddr == NULL) {
+               struct in_addr in = {INADDR_ANY}; 
+               haddr = strdup(inet_ntoa(in));
+       }
+
        if (chdir(NFS_STATEDIR)) {
                fprintf(stderr, "%s: chdir(%s) failed: %s\n",
                        argv [0], NFS_STATEDIR, strerror(errno));
@@ -116,7 +158,7 @@ main(int argc, char **argv)
        closeall(3);
 
        openlog("nfsd", LOG_PID, LOG_DAEMON);
-       if ((error = nfssvc(port, count, versbits)) < 0) {
+       if ((error = nfssvc(port, count, versbits, protobits, haddr)) < 0) {
                int e = errno;
                syslog(LOG_ERR, "nfssvc: %s", strerror(e));
                closelog();
@@ -129,7 +171,7 @@ static void
 usage(const char *prog)
 {
        fprintf(stderr, "Usage:\n"
-               "%s [-p|-P|--port port] [-N|--no-nfs-version version ] nrservs\n", 
+               "%s [-H hostname] [-p|-P|--port port] [-N|--no-nfs-version version ] [-T|--no-tcp] [-U|--no-udp] nrservs\n", 
                prog);
        exit(2);
 }