7af0ceacd0c68957fd47a2e095d39d938f13e6f2
[nfs-utils.git] / utils / statd / sm-notify.c
1 /*
 * Send NSM notify calls to all hosts listed in BASEDIR/sm (default /var/lib/nfs/sm)
3  *
4  * Copyright (C) 2004-2006 Olaf Kirch <okir@suse.de>
5  */
6
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <sys/stat.h>
10 #include <sys/poll.h>
11 #include <sys/param.h>
12 #include <sys/syslog.h>
13 #include <arpa/inet.h>
14 #include <dirent.h>
15 #include <time.h>
16 #include <stdio.h>
17 #include <getopt.h>
18 #include <stdlib.h>
19 #include <fcntl.h>
20 #include <unistd.h>
21 #include <string.h>
22 #include <stdarg.h>
23 #include <netdb.h>
24 #include <errno.h>
25
/* Base directory for the state file and the sm / sm.bak directories.
 * Can be overridden by defining BASEDIR at build time. */
#ifndef BASEDIR
#define BASEDIR         "/var/lib/nfs"
#endif

/* File holding the NSM state number read and rewritten by nsm_get_state() */
#define _SM_STATE_PATH  BASEDIR "/state"
/* Directory whose entries are moved aside by backup_hosts() */
#define _SM_DIR_PATH    BASEDIR "/sm"
/* Directory the entries are moved to, and from which get_hosts() reads them */
#define _SM_BAK_PATH    _SM_DIR_PATH ".bak"

/* Same value as NSM_PROGRAM; not referenced by the code visible in this file */
#define NSM_PROG        100024
/* RPC program number placed in portmap queries and SM_NOTIFY calls */
#define NSM_PROGRAM     100024
/* RPC program version placed in the same packets */
#define NSM_VERSION     1
/* Initial per-host retransmit timeout, in seconds */
#define NSM_TIMEOUT     2
/* RPC procedure number used for the SM_NOTIFY call */
#define NSM_NOTIFY      6
/* Upper bound, in seconds, for the doubling retransmit timeout */
#define NSM_MAX_TIMEOUT 120     /* don't make this too big */
/* Size of the send/receive buffers, counted in 32-bit words */
#define MAXMSGSIZE      256
41
/* Storage large enough for either an IPv4 or an IPv6 socket address */
typedef struct sockaddr_storage nsm_address;

/*
 * One peer that still has to be notified.  Entries live on the global
 * "hosts" list, which insert_host() keeps sorted by send_next.
 */
struct nsm_host {
        struct nsm_host *       next;           /* next entry on the hosts list */
        char *                  name;           /* strdup'ed directory entry name */
        char *                  path;           /* strdup'ed path of the entry; unlinked once notified */
        nsm_address             addr;           /* peer address; port 0 means "ask the portmapper first" */
        time_t                  last_used;      /* mtime of the entry, used as a tie-breaker when sorting */
        time_t                  send_next;      /* time at which the next packet is due */
        unsigned int            timeout;        /* current retransmit interval in seconds */
        unsigned int            retries;        /* sends since the port was last reset */
        unsigned int            xid;            /* xid of the outstanding request, 0 if none assigned */
};
55
/* Name sent as the "mon_name" string in SM_NOTIFY (from -v or gethostname) */
static char             nsm_hostname[256];
/* State number sent in SM_NOTIFY, as returned by nsm_get_state() */
static uint32_t         nsm_state;
/* -d: stay in the foreground and emit LOG_DEBUG messages */
static int              opt_debug = 0;
/* -q: suppress the informational messages guarded by this flag */
static int              opt_quiet = 0;
/* Cleared by -n; passed to nsm_get_state() as its "update" argument */
static int              opt_update_state = 1;
/* Overall retry window in seconds (-m takes minutes); 0 disables the deadline */
static unsigned int     opt_max_retry = 15 * 60;
/* -v: name or address used both as nsm_hostname and as the bind address */
static char *           opt_srcaddr = 0;
/* -p: source port to bind; 0 means use bindresvport() */
static uint16_t         opt_srcport = 0;
/* Nonzero once openlog() has been called; nsm_log() then uses vsyslog() */
static int              log_syslog = 0;

static unsigned int     nsm_get_state(int);
static void             notify(void);
static void             notify_host(int, struct nsm_host *);
static void             recv_reply(int);
static void             backup_hosts(const char *, const char *);
static void             get_hosts(const char *);
static void             insert_host(struct nsm_host *);
struct nsm_host *       find_host(uint32_t);
static int              addr_parse(int, const char *, nsm_address *);
static int              addr_get_port(nsm_address *);
static void             addr_set_port(nsm_address *, int);
static int              host_lookup(int, const char *, nsm_address *);
void                    nsm_log(int fac, const char *fmt, ...);

/* Head of the list of hosts that still have to be notified */
static struct nsm_host *        hosts = NULL;
81
82 int
83 main(int argc, char **argv)
84 {
85         int     c;
86
87         while ((c = getopt(argc, argv, "dm:np:v:q")) != -1) {
88                 switch (c) {
89                 case 'd':
90                         opt_debug++;
91                         break;
92                 case 'm':
93                         opt_max_retry = atoi(optarg) * 60;
94                         break;
95                 case 'n':
96                         opt_update_state = 0;
97                         break;
98                 case 'p':
99                         opt_srcport = atoi(optarg);
100                         break;
101                 case 'v':
102                         opt_srcaddr = optarg;
103                         break;
104                 case 'q':
105                         opt_quiet = 1;
106                         break;
107                 default:
108                         goto usage;
109                 }
110         }
111
112         if (optind < argc) {
113 usage:          fprintf(stderr, "sm-notify [-d]\n");
114                 return 1;
115         }
116
117         if (opt_srcaddr) {
118                 strncpy(nsm_hostname, opt_srcaddr, sizeof(nsm_hostname)-1);
119         } else
120         if (gethostname(nsm_hostname, sizeof(nsm_hostname)) < 0) {
121                 perror("gethostname");
122                 return 1;
123         }
124
125         backup_hosts(_SM_DIR_PATH, _SM_BAK_PATH);
126         get_hosts(_SM_BAK_PATH);
127
128         if (!opt_debug) {
129                 if (!opt_quiet)
130                         printf("Backgrounding to notify hosts...\n");
131
132                 openlog("sm-notify", LOG_PID, LOG_DAEMON);
133                 log_syslog = 1;
134
135                 if (daemon(0, 0) < 0) {
136                         nsm_log(LOG_WARNING, "unable to background: %s",
137                                         strerror(errno));
138                         return 1;
139                 }
140
141                 close(0);
142                 close(1);
143                 close(2);
144         }
145
146         /* Get and update the NSM state. This will call sync() */
147         nsm_state = nsm_get_state(opt_update_state);
148
149         notify();
150
151         if (hosts) {
152                 struct nsm_host *hp;
153
154                 while ((hp = hosts) != 0) {
155                         hosts = hp->next;
156                         nsm_log(LOG_NOTICE,
157                                 "Unable to notify %s, giving up",
158                                 hp->name);
159                 }
160                 return 1;
161         }
162
163         return 0;
164 }
165
166 /*
167  * Notify hosts
168  */
169 void
170 notify(void)
171 {
172         nsm_address local_addr;
173         time_t  failtime = 0;
174         int     sock = -1;
175
176         sock = socket(AF_INET, SOCK_DGRAM, 0);
177         if (sock < 0) {
178                 perror("socket");
179                 exit(1);
180         }
181         fcntl(sock, F_SETFL, O_NONBLOCK);
182
183         memset(&local_addr, 0, sizeof(local_addr));
184         local_addr.ss_family = AF_INET; /* Default to IPv4 */
185
186         /* Bind source IP if provided on command line */
187         if (opt_srcaddr) {
188                 if (!addr_parse(AF_INET, opt_srcaddr, &local_addr)
189                  && !host_lookup(AF_INET, opt_srcaddr, &local_addr)) {
190                         nsm_log(LOG_WARNING,
191                                 "Not a valid hostname or address: \"%s\"\n",
192                                 opt_srcaddr);
193                         exit(1);
194                 }
195                 /* We know it's IPv4 at this point */
196         }
197
198         /* Use source port if provided on the command line,
199          * otherwise use bindresvport */
200         if (opt_srcport) {
201                 addr_set_port(&local_addr, opt_srcport);
202                 if (bind(sock, (struct sockaddr *) &local_addr, sizeof(local_addr)) < 0) {
203                         perror("bind");
204                         exit(1);
205                 }
206         } else {
207                 (void) bindresvport(sock, (struct sockaddr_in *) &local_addr);
208         }
209
210         if (opt_max_retry)
211                 failtime = time(NULL) + opt_max_retry;
212
213         while (hosts) {
214                 struct pollfd   pfd;
215                 time_t          now = time(NULL);
216                 unsigned int    sent = 0;
217                 struct nsm_host *hp;
218                 long            wait;
219
220                 if (failtime && now >= failtime)
221                         break;
222
223                 while ((wait = hosts->send_next - now) <= 0) {
224                         /* Never send more than 10 packets at once */
225                         if (sent++ >= 10)
226                                 break;
227
228                         /* Remove queue head */
229                         hp = hosts;
230                         hosts = hp->next;
231
232                         notify_host(sock, hp);
233
234                         /* Set the timeout for this call, using an
235                            exponential timeout strategy */
236                         wait = hp->timeout;
237                         if ((hp->timeout <<= 1) > NSM_MAX_TIMEOUT)
238                                 hp->timeout = NSM_MAX_TIMEOUT;
239                         hp->send_next = now + wait;
240                         hp->retries++;
241
242                         insert_host(hp);
243                 }
244
245                 nsm_log(LOG_DEBUG, "Host %s due in %ld seconds",
246                                 hosts->name, wait);
247
248                 pfd.fd = sock;
249                 pfd.events = POLLIN;
250
251                 wait *= 1000;
252                 if (wait < 100)
253                         wait = 100;
254                 if (poll(&pfd, 1, wait) != 1)
255                         continue;
256
257                 recv_reply(sock);
258         }
259 }
260
261 /*
262  * Send notification to a single host
263  */
264 void
265 notify_host(int sock, struct nsm_host *host)
266 {
267         static unsigned int     xid = 0;
268         nsm_address             dest;
269         uint32_t                msgbuf[MAXMSGSIZE], *p;
270         unsigned int            len;
271
272         if (!xid)
273                 xid = getpid() + time(NULL);
274         if (!host->xid)
275                 host->xid = xid++;
276
277         memset(msgbuf, 0, sizeof(msgbuf));
278         p = msgbuf;
279         *p++ = htonl(host->xid);
280         *p++ = 0;
281         *p++ = htonl(2);
282
283         /* If we retransmitted 4 times, reset the port to force
284          * a new portmap lookup (in case statd was restarted)
285          */
286         if (host->retries >= 4) {
287                 addr_set_port(&host->addr, 0);
288                 host->retries = 0;
289         }
290
291         dest = host->addr;
292         if (addr_get_port(&dest) == 0) {
293                 /* Build a PMAP packet */
294                 nsm_log(LOG_DEBUG, "Sending portmap query to %s", host->name);
295
296                 addr_set_port(&dest, 111);
297                 *p++ = htonl(100000);
298                 *p++ = htonl(2);
299                 *p++ = htonl(3);
300
301                 /* Auth and verf */
302                 *p++ = 0; *p++ = 0;
303                 *p++ = 0; *p++ = 0;
304
305                 *p++ = htonl(NSM_PROGRAM);
306                 *p++ = htonl(NSM_VERSION);
307                 *p++ = htonl(IPPROTO_UDP);
308                 *p++ = 0;
309         } else {
310                 /* Build an SM_NOTIFY packet */
311                 nsm_log(LOG_DEBUG, "Sending SM_NOTIFY to %s", host->name);
312
313                 *p++ = htonl(NSM_PROGRAM);
314                 *p++ = htonl(NSM_VERSION);
315                 *p++ = htonl(NSM_NOTIFY);
316
317                 /* Auth and verf */
318                 *p++ = 0; *p++ = 0;
319                 *p++ = 0; *p++ = 0;
320
321                 /* state change */
322                 len = strlen(nsm_hostname);
323                 *p++ = htonl(len);
324                 memcpy(p, nsm_hostname, len);
325                 p += (len + 3) >> 2;
326                 *p++ = htonl(nsm_state);
327         }
328         len = (p - msgbuf) << 2;
329
330         sendto(sock, msgbuf, len, 0, (struct sockaddr *) &dest, sizeof(dest));
331 }
332
333 /*
334  * Receive reply from remote host
335  */
336 void
337 recv_reply(int sock)
338 {
339         struct nsm_host *hp;
340         uint32_t        msgbuf[MAXMSGSIZE], *p, *end;
341         uint32_t        xid;
342         int             res;
343
344         res = recv(sock, msgbuf, sizeof(msgbuf), 0);
345         if (res < 0)
346                 return;
347
348         nsm_log(LOG_DEBUG, "Received packet...");
349
350         p = msgbuf;
351         end = p + (res >> 2);
352
353         xid = ntohl(*p++);
354         if (*p++ != htonl(1)    /* must be REPLY */
355          || *p++ != htonl(0)    /* must be ACCEPTED */
356          || *p++ != htonl(0)    /* must be NULL verifier */
357          || *p++ != htonl(0)
358          || *p++ != htonl(0))   /* must be SUCCESS */
359                 return;
360
361         /* Before we look at the data, find the host struct for
362            this reply */
363         if ((hp = find_host(xid)) == NULL)
364                 return;
365
366         if (addr_get_port(&hp->addr) == 0) {
367                 /* This was a portmap request */
368                 unsigned int    port;
369
370                 port = ntohl(*p++);
371                 if (p > end)
372                         goto fail;
373
374                 hp->send_next = time(NULL);
375                 if (port == 0) {
376                         /* No binding for statd. Delay the next
377                          * portmap query for max timeout */
378                         nsm_log(LOG_DEBUG, "No statd on %s", hp->name);
379                         hp->timeout = NSM_MAX_TIMEOUT;
380                         hp->send_next += NSM_MAX_TIMEOUT;
381                 } else {
382                         addr_set_port(&hp->addr, port);
383                         if (hp->timeout >= NSM_MAX_TIMEOUT / 4)
384                                 hp->timeout = NSM_MAX_TIMEOUT / 4;
385                 }
386                 hp->xid = 0;
387         } else {
388                 /* Successful NOTIFY call. Server returns void,
389                  * so nothing we need to do here (except
390                  * check that we didn't read past the end of the
391                  * packet)
392                  */
393                 if (p <= end) {
394                         nsm_log(LOG_DEBUG, "Host %s notified successfully", hp->name);
395                         unlink(hp->path);
396                         free(hp->name);
397                         free(hp->path);
398                         free(hp);
399                         return;
400                 }
401         }
402
403 fail:   /* Re-insert the host */
404         insert_host(hp);
405 }
406
/*
 * Move every entry of dirname whose name does not start with '.'
 * into bakname.  A failing rename is logged and the scan goes on;
 * if dirname cannot be opened the error is printed and nothing is
 * moved.
 */
static void
backup_hosts(const char *dirname, const char *bakname)
{
        DIR             *dp = opendir(dirname);
        struct dirent   *entry;

        if (dp == NULL) {
                perror(dirname);
                return;
        }

        for (entry = readdir(dp); entry != NULL; entry = readdir(dp)) {
                char    from[1024], to[1024];

                /* Skip ".", ".." and hidden entries */
                if (entry->d_name[0] == '.')
                        continue;

                snprintf(from, sizeof(from), "%s/%s", dirname, entry->d_name);
                snprintf(to, sizeof(to), "%s/%s", bakname, entry->d_name);

                if (rename(from, to) >= 0)
                        continue;

                nsm_log(LOG_WARNING,
                        "Failed to rename %s -> %s: %m",
                        from, to);
        }

        closedir(dp);
}
437
438 /*
439  * Get all entries from sm.bak and convert them to host names
440  */
441 static void
442 get_hosts(const char *dirname)
443 {
444         struct nsm_host *host;
445         struct dirent   *de;
446         DIR             *dir;
447
448         if (!(dir = opendir(dirname))) {
449                 perror(dirname);
450                 return;
451         }
452
453         host = NULL;
454         while ((de = readdir(dir)) != NULL) {
455                 struct stat     stb;
456                 char            path[1024];
457
458                 if (de->d_name[0] == '.')
459                         continue;
460                 if (host == NULL)
461                         host = calloc(1, sizeof(*host));
462
463                 snprintf(path, sizeof(path), "%s/%s", dirname, de->d_name);
464                 if (!addr_parse(AF_INET, de->d_name, &host->addr)
465                  && !addr_parse(AF_INET6, de->d_name, &host->addr)
466                  && !host_lookup(AF_INET, de->d_name, &host->addr)) {
467                         nsm_log(LOG_WARNING,
468                                 "%s doesn't seem to be a valid address, skipped",
469                                 de->d_name);
470                         unlink(path);
471                         continue;
472                 }
473
474                 if (stat(path, &stb) < 0)
475                         continue;
476                 host->last_used = stb.st_mtime;
477                 host->timeout = NSM_TIMEOUT;
478                 host->path = strdup(path);
479                 host->name = strdup(de->d_name);
480
481                 insert_host(host);
482                 host = NULL;
483         }
484         closedir(dir);
485
486         if (host)
487                 free(host);
488 }
489
490 /*
491  * Insert host into sorted list
492  */
493 void
494 insert_host(struct nsm_host *host)
495 {
496         struct nsm_host **where, *p;
497
498         where = &hosts;
499         while ((p = *where) != 0) {
500                 /* Sort in ascending order of timeout */
501                 if (host->send_next < p->send_next)
502                         break;
503                 /* If we have the same timeout, put the
504                  * most recently used host first.
505                  * This makes sure that "recent" hosts
506                  * get notified first.
507                  */
508                 if (host->send_next == p->send_next
509                  && host->last_used > p->last_used)
510                         break;
511                 where = &p->next;
512         }
513
514         host->next = *where;
515         *where = host;
516 }
517
518 /*
519  * Find host given the XID
520  */
521 struct nsm_host *
522 find_host(uint32_t xid)
523 {
524         struct nsm_host **where, *p;
525
526         where = &hosts;
527         while ((p = *where) != 0) {
528                 if (p->xid == xid) {
529                         *where = p->next;
530                         return p;
531                 }
532                 where = &p->next;
533         }
534         return NULL;
535 }
536
537
538 /*
539  * Retrieve the current NSM state
540  */
541 unsigned int
542 nsm_get_state(int update)
543 {
544         char            newfile[PATH_MAX];
545         int             fd, state;
546
547         if ((fd = open(_SM_STATE_PATH, O_RDONLY)) < 0) {
548                 if (!opt_quiet) {
549                         nsm_log(LOG_WARNING, "%s: %m", _SM_STATE_PATH);
550                         nsm_log(LOG_WARNING, "Creating %s, set initial state 1",
551                                 _SM_STATE_PATH);
552                 }
553                 state = 1;
554                 update = 1;
555         } else {
556                 if (read(fd, &state, sizeof(state)) != sizeof(state)) {
557                         nsm_log(LOG_WARNING,
558                                 "%s: bad file size, setting state = 1",
559                                 _SM_STATE_PATH);
560                         state = 1;
561                         update = 1;
562                 } else {
563                         if (!(state & 1))
564                                 state += 1;
565                 }
566                 close(fd);
567         }
568
569         if (update) {
570                 state += 2;
571                 snprintf(newfile, sizeof(newfile),
572                                 "%s.new", _SM_STATE_PATH);
573                 if ((fd = open(newfile, O_CREAT|O_WRONLY, 0644)) < 0) {
574                         nsm_log(LOG_WARNING, "Cannot create %s: %m", newfile);
575                         exit(1);
576                 }
577                 if (write(fd, &state, sizeof(state)) != sizeof(state)) {
578                         nsm_log(LOG_WARNING,
579                                 "Failed to write state to %s", newfile);
580                         exit(1);
581                 }
582                 close(fd);
583                 if (rename(newfile, _SM_STATE_PATH) < 0) {
584                         nsm_log(LOG_WARNING,
585                                 "Cannot create %s: %m", _SM_STATE_PATH);
586                         exit(1);
587                 }
588                 sync();
589         }
590
591         return state;
592 }
593
594 /*
595  * Address handling utilities
596  */
597 static int
598 addr_parse(int af, const char *name, nsm_address *addr)
599 {
600         void    *ptr;
601
602         if (af == AF_INET)
603                 ptr = &((struct sockaddr_in *) addr)->sin_addr;
604         else if (af == AF_INET6)
605                 ptr = &((struct sockaddr_in6 *) addr)->sin6_addr;
606         else
607                 return 0;
608         if (inet_pton(af, name, ptr) <= 0)
609                 return 0;
610         ((struct sockaddr *) addr)->sa_family = af;
611         return 1;
612 }
613
614 int
615 addr_get_port(nsm_address *addr)
616 {
617         switch (((struct sockaddr *) addr)->sa_family) {
618         case AF_INET:
619                 return ntohs(((struct sockaddr_in *) addr)->sin_port);
620         case AF_INET6:
621                 return ntohs(((struct sockaddr_in6 *) addr)->sin6_port);
622         }
623         return 0;
624 }
625
626 static void
627 addr_set_port(nsm_address *addr, int port)
628 {
629         switch (((struct sockaddr *) addr)->sa_family) {
630         case AF_INET:
631                 ((struct sockaddr_in *) addr)->sin_port = htons(port);
632                 break;
633         case AF_INET6:
634                 ((struct sockaddr_in6 *) addr)->sin6_port = htons(port);
635         }
636 }
637
638 static int
639 host_lookup(int af, const char *name, nsm_address *addr)
640 {
641         struct addrinfo hints, *ai;
642         int okay = 0;
643
644         memset(&hints, 0, sizeof(hints));
645         hints.ai_family = af;
646
647         if (getaddrinfo(name, NULL, &hints, &ai) != 0)
648                 return 0;
649
650         if (ai->ai_addrlen < sizeof(*addr)) {
651                 memcpy(addr, ai->ai_addr, ai->ai_addrlen);
652                 okay = 1;
653         }
654
655         freeaddrinfo(ai);
656         return okay;
657 }
658
659 /*
660  * Log a message
661  */
662 void
663 nsm_log(int fac, const char *fmt, ...)
664 {
665         va_list ap;
666
667         if (fac == LOG_DEBUG && !opt_debug)
668                 return;
669
670         va_start(ap, fmt);
671         if (log_syslog)
672                 vsyslog(fac, fmt, ap);
673         else {
674                 vfprintf(stderr, fmt, ap);
675                 fputs("\n", stderr);
676         }
677         va_end(ap);
678 }