]> git.decadent.org.uk Git - nfs-utils.git/blob - utils/gssd/gssd_proc.c
gssd: add upcall support for callback authentication
[nfs-utils.git] / utils / gssd / gssd_proc.c
1 /*
2   gssd_proc.c
3
4   Copyright (c) 2000-2004 The Regents of the University of Michigan.
5   All rights reserved.
6
7   Copyright (c) 2000 Dug Song <dugsong@UMICH.EDU>.
8   Copyright (c) 2001 Andy Adamson <andros@UMICH.EDU>.
9   Copyright (c) 2002 Marius Aamodt Eriksen <marius@UMICH.EDU>.
10   Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
11   Copyright (c) 2004 Kevin Coffman <kwc@umich.edu>
12   All rights reserved, all wrongs reversed.
13
14   Redistribution and use in source and binary forms, with or without
15   modification, are permitted provided that the following conditions
16   are met:
17
18   1. Redistributions of source code must retain the above copyright
19      notice, this list of conditions and the following disclaimer.
20   2. Redistributions in binary form must reproduce the above copyright
21      notice, this list of conditions and the following disclaimer in the
22      documentation and/or other materials provided with the distribution.
23   3. Neither the name of the University nor the names of its
24      contributors may be used to endorse or promote products derived
25      from this software without specific prior written permission.
26
27   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
28   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
29   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30   DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
31   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
32   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
33   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
34   BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
35   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
36   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
37   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38
39 */
40
41 #ifdef HAVE_CONFIG_H
42 #include <config.h>
43 #endif  /* HAVE_CONFIG_H */
44
45 #ifndef _GNU_SOURCE
46 #define _GNU_SOURCE
47 #endif
48
49 #include <sys/param.h>
50 #include <rpc/rpc.h>
51 #include <sys/stat.h>
52 #include <sys/socket.h>
53 #include <arpa/inet.h>
54 #include <sys/fsuid.h>
55
56 #include <stdio.h>
57 #include <stdlib.h>
58 #include <pwd.h>
59 #include <grp.h>
60 #include <string.h>
61 #include <dirent.h>
62 #include <poll.h>
63 #include <fcntl.h>
64 #include <signal.h>
65 #include <unistd.h>
66 #include <errno.h>
67 #include <gssapi/gssapi.h>
68 #include <netdb.h>
69
70 #include "gssd.h"
71 #include "err_util.h"
72 #include "gss_util.h"
73 #include "krb5_util.h"
74 #include "context.h"
75 #include "nfsrpc.h"
76
77 /*
78  * pollarray:
79  *      array of struct pollfd suitable to pass to poll. initialized to
80  *      zero - a zero struct is ignored by poll() because the events mask is 0.
81  *
82  * clnt_list:
83  *      linked list of struct clnt_info which associates a clntXXX directory
84  *      with an index into pollarray[], and other basic data about that client.
85  *
86  * Directory structure: created by the kernel
87  *      {rpc_pipefs}/{dir}/clntXX         : one per rpc_clnt struct in the kernel
88  *      {rpc_pipefs}/{dir}/clntXX/krb5    : read uid for which kernel wants
89  *                                          a context, write the resulting context
90  *      {rpc_pipefs}/{dir}/clntXX/info    : stores info such as server name
91  *
92  * Algorithm:
93  *      Poll all {rpc_pipefs}/{dir}/clntXX/krb5 files.  When data is ready,
94  *      read and process; performs rpcsec_gss context initialization protocol to
95  *      get a cred for that user.  Writes result to corresponding krb5 file
96  *      in a form the kernel code will understand.
97  *      In addition, we make sure we are notified whenever anything is
98  *      created or destroyed in {rpc_pipefs} or in any of the clntXX directories,
99  *      and rescan the whole {rpc_pipefs} when this happens.
100  */
101
102 struct pollfd * pollarray;  /* allocated in init_client_list() */
103
104 int pollsize;  /* the size of pollarray (in pollfd's) */
105
/*
 * Convert a presentation address string into a socket address.
 *
 * sa:   caller-supplied storage; written as a struct sockaddr_in, or as a
 *       struct sockaddr_in6 when IPV6_SUPPORTED is defined and the string
 *       is an IPv6 address.  Only the family, address and port fields are
 *       written here.
 * addr: textual IPv4 (or IPv6) address
 * port: port number in host byte order
 *
 * Returns 1 on success and 0 on failure.
 *
 * Note that we do not populate the sin6_scope_id field here for IPv6 addrs.
 * gssd necessarily relies on hostname resolution and DNS AAAA records
 * do not generally contain scope-id's. This means that GSSAPI auth really
 * can't work with IPv6 link-local addresses.
 *
 * We *could* consider changing this if we did something like adopt the
 * Microsoft "standard" of using the ipv6-literal.net domainname, but it's
 * not really feasible at present.
 */
static int
addrstr_to_sockaddr(struct sockaddr *sa, const char *addr, const int port)
{
	struct sockaddr_in	*s4 = (struct sockaddr_in *) sa;
#ifdef IPV6_SUPPORTED
	struct sockaddr_in6	*s6 = (struct sockaddr_in6 *) sa;
#endif /* IPV6_SUPPORTED */

	/*
	 * inet_pton() returns 1 only when the string was converted; both
	 * 0 (not a valid address) and -1 (unsupported family) are failures,
	 * so test for 1 explicitly rather than for "non-zero".
	 */
	if (inet_pton(AF_INET, addr, &s4->sin_addr) == 1) {
		s4->sin_family = AF_INET;
		s4->sin_port = htons(port);
#ifdef IPV6_SUPPORTED
	} else if (inet_pton(AF_INET6, addr, &s6->sin6_addr) == 1) {
		s6->sin6_family = AF_INET6;
		s6->sin6_port = htons(port);
#endif /* IPV6_SUPPORTED */
	} else {
		printerr(0, "ERROR: unable to convert %s to address\n", addr);
		return 0;
	}

	return 1;
}
142
143 /*
144  * convert a sockaddr to a hostname
145  */
146 static char *
147 sockaddr_to_hostname(const struct sockaddr *sa, const char *addr)
148 {
149         socklen_t               addrlen;
150         int                     err;
151         char                    *hostname;
152         char                    hbuf[NI_MAXHOST];
153
154         switch (sa->sa_family) {
155         case AF_INET:
156                 addrlen = sizeof(struct sockaddr_in);
157                 break;
158 #ifdef IPV6_SUPPORTED
159         case AF_INET6:
160                 addrlen = sizeof(struct sockaddr_in6);
161                 break;
162 #endif /* IPV6_SUPPORTED */
163         default:
164                 printerr(0, "ERROR: unrecognized addr family %d\n",
165                          sa->sa_family);
166                 return NULL;
167         }
168
169         err = getnameinfo(sa, addrlen, hbuf, sizeof(hbuf), NULL, 0,
170                           NI_NAMEREQD);
171         if (err) {
172                 printerr(0, "ERROR: unable to resolve %s to hostname: %s\n",
173                          addr, err == EAI_SYSTEM ? strerror(err) :
174                                                    gai_strerror(err));
175                 return NULL;
176         }
177
178         hostname = strdup(hbuf);
179
180         return hostname;
181 }
182
/*
 * Read and parse a client "info" file.
 *
 * info_file_name: path of the info file to read
 * servicename:    out: malloc'd "<service>@<hostname>" string
 * servername:     out: malloc'd hostname obtained by resolving the
 *                 "address:" field
 * prog, vers:     out: RPC program and version numbers from the
 *                 "service:" line
 * protocol:       out: malloc'd protocol name; "tcp" is assumed when the
 *                 file has no "protocol:" line
 * addr:           out: socket address built from the "address:" field and
 *                 the optional "port:" field (port 0 when absent)
 *
 * Returns 0 on success and -1 on failure.  On every failure path that goes
 * through "fail" the three string outputs are freed and reset to NULL.  A
 * service name that does not start with "nfs" returns -1 without logging.
 *
 * XXX buffer problems: the file is read with a single read() into a
 * fixed INFOBUFLEN buffer, and the parsed fields live in fixed-size
 * (partly static) arrays.
 */
static int
read_service_info(char *info_file_name, char **servicename, char **servername,
		  int *prog, int *vers, char **protocol,
		  struct sockaddr *addr) {
#define INFOBUFLEN 256
	char		buf[INFOBUFLEN + 1];
	static char	dummy[128];
	int		nbytes;
	static char	service[128];
	static char	address[128];
	char		program[16];
	char		version[16];
	char		protoname[16];
	char		cb_port[128];
	char		*p;
	int		fd = -1;
	int		numfields;
	int		port = 0;

	*servicename = *servername = *protocol = NULL;

	if ((fd = open(info_file_name, O_RDONLY)) == -1) {
		printerr(0, "ERROR: can't open %s: %s\n", info_file_name,
			 strerror(errno));
		goto fail;
	}
	if ((nbytes = read(fd, buf, INFOBUFLEN)) == -1)
		goto fail;
	close(fd);
	/* Mark the descriptor as gone so the fail path cannot close it again */
	fd = -1;
	buf[nbytes] = '\0';

	numfields = sscanf(buf,"RPC server: %127s\n"
		   "service: %127s %15s version %15s\n"
		   "address: %127s\n"
		   "protocol: %15s\n",
		   dummy,
		   service, program, version,
		   address,
		   protoname);

	/* Five fields means there was no "protocol:" line: default to tcp */
	if (numfields == 5) {
		strcpy(protoname, "tcp");
	} else if (numfields != 6) {
		goto fail;
	}

	/*
	 * Look for the optional "port:" line.  Anchor the search at the
	 * start of a line so that "port" appearing inside an earlier field
	 * (for example a server name) is not mistaken for it.
	 */
	cb_port[0] = '\0';
	if ((p = strstr(buf, "\nport:")) != NULL)
		sscanf(p + 1, "port: %127s\n", cb_port);

	/* check service, program, and version */
	if (memcmp(service, "nfs", 3) != 0)
		return -1;
	*prog = atoi(program + 1); /* skip open paren */
	*vers = atoi(version);

	if (strlen(service) == 3 ) {
		if ((*prog != 100003) || ((*vers != 2) && (*vers != 3) &&
		    (*vers != 4)))
			goto fail;
	} else if (memcmp(service, "nfs4_cb", 7) == 0) {
		if (*vers != 1)
			goto fail;
	}

	if (cb_port[0] != '\0') {
		port = atoi(cb_port);
		if (port < 0 || port > 65535)
			goto fail;
	}

	if (!addrstr_to_sockaddr(addr, address, port))
		goto fail;

	*servername = sockaddr_to_hostname(addr, address);
	if (*servername == NULL)
		goto fail;

	/*
	 * snprintf() returns the length it wanted to write; anything
	 * >= the size passed in means the result was truncated.
	 */
	nbytes = snprintf(buf, INFOBUFLEN, "%s@%s", service, *servername);
	if (nbytes < 0 || nbytes >= INFOBUFLEN)
		goto fail;

	if (!(*servicename = strdup(buf)))
		goto fail;

	if (!(*protocol = strdup(protoname)))
		goto fail;
	return 0;
fail:
	printerr(0, "ERROR: failed to read service info\n");
	if (fd != -1) close(fd);
	free(*servername);
	free(*servicename);
	free(*protocol);
	*servicename = *servername = *protocol = NULL;
	return -1;
}
282
283 static void
284 destroy_client(struct clnt_info *clp)
285 {
286         if (clp->krb5_poll_index != -1)
287                 memset(&pollarray[clp->krb5_poll_index], 0,
288                                         sizeof(struct pollfd));
289         if (clp->spkm3_poll_index != -1)
290                 memset(&pollarray[clp->spkm3_poll_index], 0,
291                                         sizeof(struct pollfd));
292         if (clp->dir_fd != -1) close(clp->dir_fd);
293         if (clp->krb5_fd != -1) close(clp->krb5_fd);
294         if (clp->spkm3_fd != -1) close(clp->spkm3_fd);
295         free(clp->dirname);
296         free(clp->servicename);
297         free(clp->servername);
298         free(clp->protocol);
299         free(clp);
300 }
301
302 static struct clnt_info *
303 insert_new_clnt(void)
304 {
305         struct clnt_info        *clp = NULL;
306
307         if (!(clp = (struct clnt_info *)calloc(1,sizeof(struct clnt_info)))) {
308                 printerr(0, "ERROR: can't malloc clnt_info: %s\n",
309                          strerror(errno));
310                 goto out;
311         }
312         clp->krb5_poll_index = -1;
313         clp->spkm3_poll_index = -1;
314         clp->krb5_fd = -1;
315         clp->spkm3_fd = -1;
316         clp->dir_fd = -1;
317
318         TAILQ_INSERT_HEAD(&clnt_list, clp, list);
319 out:
320         return clp;
321 }
322
323 static int
324 process_clnt_dir_files(struct clnt_info * clp)
325 {
326         char    name[PATH_MAX];
327         char    info_file_name[PATH_MAX];
328
329         if (clp->krb5_fd == -1) {
330                 snprintf(name, sizeof(name), "%s/krb5", clp->dirname);
331                 clp->krb5_fd = open(name, O_RDWR);
332         }
333         if (clp->spkm3_fd == -1) {
334                 snprintf(name, sizeof(name), "%s/spkm3", clp->dirname);
335                 clp->spkm3_fd = open(name, O_RDWR);
336         }
337         if ((clp->krb5_fd == -1) && (clp->spkm3_fd == -1))
338                 return -1;
339         snprintf(info_file_name, sizeof(info_file_name), "%s/info",
340                         clp->dirname);
341         if ((clp->servicename == NULL) &&
342              read_service_info(info_file_name, &clp->servicename,
343                                 &clp->servername, &clp->prog, &clp->vers,
344                                 &clp->protocol, (struct sockaddr *) &clp->addr))
345                 return -1;
346         return 0;
347 }
348
349 static int
350 get_poll_index(int *ind)
351 {
352         int i;
353
354         *ind = -1;
355         for (i=0; i<FD_ALLOC_BLOCK; i++) {
356                 if (pollarray[i].events == 0) {
357                         *ind = i;
358                         break;
359                 }
360         }
361         if (*ind == -1) {
362                 printerr(0, "ERROR: No pollarray slots open\n");
363                 return -1;
364         }
365         return 0;
366 }
367
368
369 static int
370 insert_clnt_poll(struct clnt_info *clp)
371 {
372         if ((clp->krb5_fd != -1) && (clp->krb5_poll_index == -1)) {
373                 if (get_poll_index(&clp->krb5_poll_index)) {
374                         printerr(0, "ERROR: Too many krb5 clients\n");
375                         return -1;
376                 }
377                 pollarray[clp->krb5_poll_index].fd = clp->krb5_fd;
378                 pollarray[clp->krb5_poll_index].events |= POLLIN;
379         }
380
381         if ((clp->spkm3_fd != -1) && (clp->spkm3_poll_index == -1)) {
382                 if (get_poll_index(&clp->spkm3_poll_index)) {
383                         printerr(0, "ERROR: Too many spkm3 clients\n");
384                         return -1;
385                 }
386                 pollarray[clp->spkm3_poll_index].fd = clp->spkm3_fd;
387                 pollarray[clp->spkm3_poll_index].events |= POLLIN;
388         }
389
390         return 0;
391 }
392
393 static void
394 process_clnt_dir(char *dir, char *pdir)
395 {
396         struct clnt_info *      clp;
397
398         if (!(clp = insert_new_clnt()))
399                 goto fail_destroy_client;
400
401         /* An extra for the '/', and an extra for the null */
402         if (!(clp->dirname = calloc(strlen(dir) + strlen(pdir) + 2, 1))) {
403                 goto fail_destroy_client;
404         }
405         sprintf(clp->dirname, "%s/%s", pdir, dir);
406         if ((clp->dir_fd = open(clp->dirname, O_RDONLY)) == -1) {
407                 printerr(0, "ERROR: can't open %s: %s\n",
408                          clp->dirname, strerror(errno));
409                 goto fail_destroy_client;
410         }
411         fcntl(clp->dir_fd, F_SETSIG, DNOTIFY_SIGNAL);
412         fcntl(clp->dir_fd, F_NOTIFY, DN_CREATE | DN_DELETE | DN_MULTISHOT);
413
414         if (process_clnt_dir_files(clp))
415                 goto fail_keep_client;
416
417         if (insert_clnt_poll(clp))
418                 goto fail_destroy_client;
419
420         return;
421
422 fail_destroy_client:
423         if (clp) {
424                 TAILQ_REMOVE(&clnt_list, clp, list);
425                 destroy_client(clp);
426         }
427 fail_keep_client:
428         /* We couldn't find some subdirectories, but we keep the client
429          * around in case we get a notification on the directory when the
430          * subdirectories are created. */
431         return;
432 }
433
434 void
435 init_client_list(void)
436 {
437         TAILQ_INIT(&clnt_list);
438         /* Eventually plan to grow/shrink poll array: */
439         pollsize = FD_ALLOC_BLOCK;
440         pollarray = calloc(pollsize, sizeof(struct pollfd));
441 }
442
443 /*
444  * This is run after a DNOTIFY signal, and should clear up any
445  * directories that are no longer around, and re-scan any existing
446  * directories, since the DNOTIFY could have been in there.
447  */
448 static void
449 update_old_clients(struct dirent **namelist, int size, char *pdir)
450 {
451         struct clnt_info *clp;
452         void *saveprev;
453         int i, stillhere;
454         char fname[PATH_MAX];
455
456         for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
457                 /* only compare entries in the global list that are from the
458                  * same pipefs parent directory as "pdir"
459                  */
460                 if (strncmp(clp->dirname, pdir, strlen(pdir)) != 0) continue;
461
462                 stillhere = 0;
463                 for (i=0; i < size; i++) {
464                         snprintf(fname, sizeof(fname), "%s/%s",
465                                  pdir, namelist[i]->d_name);
466                         if (strcmp(clp->dirname, fname) == 0) {
467                                 stillhere = 1;
468                                 break;
469                         }
470                 }
471                 if (!stillhere) {
472                         printerr(2, "destroying client %s\n", clp->dirname);
473                         saveprev = clp->list.tqe_prev;
474                         TAILQ_REMOVE(&clnt_list, clp, list);
475                         destroy_client(clp);
476                         clp = saveprev;
477                 }
478         }
479         for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
480                 if (!process_clnt_dir_files(clp))
481                         insert_clnt_poll(clp);
482         }
483 }
484
485 /* Search for a client by directory name, return 1 if found, 0 otherwise */
486 static int
487 find_client(char *dirname, char *pdir)
488 {
489         struct clnt_info        *clp;
490         char fname[PATH_MAX];
491
492         for (clp = clnt_list.tqh_first; clp != NULL; clp = clp->list.tqe_next) {
493                 snprintf(fname, sizeof(fname), "%s/%s", pdir, dirname);
494                 if (strcmp(clp->dirname, fname) == 0)
495                         return 1;
496         }
497         return 0;
498 }
499
500 static int
501 process_pipedir(char *pipe_name)
502 {
503         struct dirent **namelist;
504         int i, j;
505
506         if (chdir(pipe_name) < 0) {
507                 printerr(0, "ERROR: can't chdir to %s: %s\n",
508                          pipe_name, strerror(errno));
509                 return -1;
510         }
511
512         j = scandir(pipe_name, &namelist, NULL, alphasort);
513         if (j < 0) {
514                 printerr(0, "ERROR: can't scandir %s: %s\n",
515                          pipe_name, strerror(errno));
516                 return -1;
517         }
518
519         update_old_clients(namelist, j, pipe_name);
520         for (i=0; i < j; i++) {
521                 if (i < FD_ALLOC_BLOCK
522                                 && !strncmp(namelist[i]->d_name, "clnt", 4)
523                                 && !find_client(namelist[i]->d_name, pipe_name))
524                         process_clnt_dir(namelist[i]->d_name, pipe_name);
525                 free(namelist[i]);
526         }
527
528         free(namelist);
529
530         return 0;
531 }
532
533 /* Used to read (and re-read) list of clients, set up poll array. */
534 int
535 update_client_list(void)
536 {
537         int retval = -1;
538         struct topdirs_info *tdi;
539
540         TAILQ_FOREACH(tdi, &topdirs_list, list) {
541                 retval = process_pipedir(tdi->dirname);
542                 if (retval)
543                         printerr(1, "WARNING: error processing %s\n",
544                                  tdi->dirname);
545
546         }
547         return retval;
548 }
549
550 static int
551 do_downcall(int k5_fd, uid_t uid, struct authgss_private_data *pd,
552             gss_buffer_desc *context_token)
553 {
554         char    *buf = NULL, *p = NULL, *end = NULL;
555         unsigned int timeout = context_timeout;
556         unsigned int buf_size = 0;
557
558         printerr(1, "doing downcall\n");
559         buf_size = sizeof(uid) + sizeof(timeout) + sizeof(pd->pd_seq_win) +
560                 sizeof(pd->pd_ctx_hndl.length) + pd->pd_ctx_hndl.length +
561                 sizeof(context_token->length) + context_token->length;
562         p = buf = malloc(buf_size);
563         end = buf + buf_size;
564
565         if (WRITE_BYTES(&p, end, uid)) goto out_err;
566         if (WRITE_BYTES(&p, end, timeout)) goto out_err;
567         if (WRITE_BYTES(&p, end, pd->pd_seq_win)) goto out_err;
568         if (write_buffer(&p, end, &pd->pd_ctx_hndl)) goto out_err;
569         if (write_buffer(&p, end, context_token)) goto out_err;
570
571         if (write(k5_fd, buf, p - buf) < p - buf) goto out_err;
572         if (buf) free(buf);
573         return 0;
574 out_err:
575         if (buf) free(buf);
576         printerr(1, "Failed to write downcall!\n");
577         return -1;
578 }
579
/*
 * Write an error downcall to the kernel.
 *
 * The message consists of, in order: uid, a zero timeout, a zero
 * sequence window (which marks the message as an error) and the error
 * code.
 *
 * Returns 0 on success, -1 if a field does not fit or the write is short.
 */
static int
do_error_downcall(int k5_fd, uid_t uid, int err)
{
	char		msg[1024];
	char		*cursor = msg;
	char		*limit = msg + sizeof(msg);
	unsigned int	timeout = 0;
	int		zero = 0;

	printerr(1, "doing error downcall\n");

	if (WRITE_BYTES(&cursor, limit, uid))
		goto out_err;
	if (WRITE_BYTES(&cursor, limit, timeout))
		goto out_err;
	/* use seq_win = 0 to indicate an error: */
	if (WRITE_BYTES(&cursor, limit, zero))
		goto out_err;
	if (WRITE_BYTES(&cursor, limit, err))
		goto out_err;

	if (write(k5_fd, msg, cursor - msg) < cursor - msg)
		goto out_err;

	return 0;

out_err:
	printerr(1, "Failed to write error downcall!\n");
	return -1;
}
602
603 /*
604  * If the port isn't already set, do an rpcbind query to the remote server
605  * using the program and version and get the port. 
606  *
607  * Newer kernels send the value of the port= mount option in the "info"
608  * file for the upcall or '0' for NFSv2/3. For NFSv4 it sends the value
609  * of the port= option or '2049'. The port field in a new sockaddr should
610  * reflect the value that was sent by the kernel.
611  */
612 static int
613 populate_port(struct sockaddr *sa, const socklen_t salen,
614               const rpcprog_t program, const rpcvers_t version,
615               const unsigned short protocol)
616 {
617         struct sockaddr_in      *s4 = (struct sockaddr_in *) sa;
618 #ifdef IPV6_SUPPORTED
619         struct sockaddr_in6     *s6 = (struct sockaddr_in6 *) sa;
620 #endif /* IPV6_SUPPORTED */
621         unsigned short          port;
622
623         /*
624          * Newer kernels send the port in the upcall. If we already have
625          * the port, there's no need to look it up.
626          */
627         switch (sa->sa_family) {
628         case AF_INET:
629                 if (s4->sin_port != 0) {
630                         printerr(2, "DEBUG: port already set to %d\n",
631                                  ntohs(s4->sin_port));
632                         return 1;
633                 }
634                 break;
635 #ifdef IPV6_SUPPORTED
636         case AF_INET6:
637                 if (s6->sin6_port != 0) {
638                         printerr(2, "DEBUG: port already set to %d\n",
639                                  ntohs(s6->sin6_port));
640                         return 1;
641                 }
642                 break;
643 #endif /* IPV6_SUPPORTED */
644         default:
645                 printerr(0, "ERROR: unsupported address family %d\n",
646                             sa->sa_family);
647                 return 0;
648         }
649
650         /*
651          * Newer kernels that send the port in the upcall set the value to
652          * 2049 for NFSv4 mounts when one isn't specified. The check below is
653          * only for kernels that don't send the port in the upcall. For those
654          * we either have to do an rpcbind query or set it to the standard
655          * port. Doing a query could be problematic (firewalls, etc), so take
656          * the latter approach.
657          */
658         if (program == 100003 && version == 4) {
659                 port = 2049;
660                 goto set_port;
661         }
662
663         port = nfs_getport(sa, salen, program, version, protocol);
664         if (!port) {
665                 printerr(0, "ERROR: unable to obtain port for prog %ld "
666                             "vers %ld\n", program, version);
667                 return 0;
668         }
669
670 set_port:
671         printerr(2, "DEBUG: setting port to %hu for prog %lu vers %lu\n", port,
672                  program, version);
673
674         switch (sa->sa_family) {
675         case AF_INET:
676                 s4->sin_port = htons(port);
677                 break;
678 #ifdef IPV6_SUPPORTED
679         case AF_INET6:
680                 s6->sin6_port = htons(port);
681                 break;
682 #endif /* IPV6_SUPPORTED */
683         }
684
685         return 1;
686 }
687
688 /*
689  * Create an RPC connection and establish an authenticated
690  * gss context with a server.
691  */
692 int create_auth_rpc_client(struct clnt_info *clp,
693                            CLIENT **clnt_return,
694                            AUTH **auth_return,
695                            uid_t uid,
696                            int authtype)
697 {
698         CLIENT                  *rpc_clnt = NULL;
699         struct rpc_gss_sec      sec;
700         AUTH                    *auth = NULL;
701         uid_t                   save_uid = -1;
702         int                     retval = -1;
703         OM_uint32               min_stat;
704         char                    rpc_errmsg[1024];
705         int                     protocol;
706         struct timeval          timeout = {5, 0};
707         struct sockaddr         *addr = (struct sockaddr *) &clp->addr;
708         socklen_t               salen;
709
710         /* Create the context as the user (not as root) */
711         save_uid = geteuid();
712         if (setfsuid(uid) != 0) {
713                 printerr(0, "WARNING: Failed to setfsuid for "
714                             "user with uid %d\n", uid);
715                 goto out_fail;
716         }
717         printerr(2, "creating context using fsuid %d (save_uid %d)\n",
718                         uid, save_uid);
719
720         sec.qop = GSS_C_QOP_DEFAULT;
721         sec.svc = RPCSEC_GSS_SVC_NONE;
722         sec.cred = GSS_C_NO_CREDENTIAL;
723         sec.req_flags = 0;
724         if (authtype == AUTHTYPE_KRB5) {
725                 sec.mech = (gss_OID)&krb5oid;
726                 sec.req_flags = GSS_C_MUTUAL_FLAG;
727         }
728         else if (authtype == AUTHTYPE_SPKM3) {
729                 sec.mech = (gss_OID)&spkm3oid;
730                 /* XXX sec.req_flags = GSS_C_ANON_FLAG;
731                  * Need a way to switch....
732                  */
733                 sec.req_flags = GSS_C_MUTUAL_FLAG;
734         }
735         else {
736                 printerr(0, "ERROR: Invalid authentication type (%d) "
737                         "in create_auth_rpc_client\n", authtype);
738                 goto out_fail;
739         }
740
741
742         if (authtype == AUTHTYPE_KRB5) {
743 #ifdef HAVE_SET_ALLOWABLE_ENCTYPES
744                 /*
745                  * Do this before creating rpc connection since we won't need
746                  * rpc connection if it fails!
747                  */
748                 if (limit_krb5_enctypes(&sec, uid)) {
749                         printerr(1, "WARNING: Failed while limiting krb5 "
750                                     "encryption types for user with uid %d\n",
751                                  uid);
752                         goto out_fail;
753                 }
754 #endif
755         }
756
757         /* create an rpc connection to the nfs server */
758
759         printerr(2, "creating %s client for server %s\n", clp->protocol,
760                         clp->servername);
761
762         if ((strcmp(clp->protocol, "tcp")) == 0) {
763                 protocol = IPPROTO_TCP;
764         } else if ((strcmp(clp->protocol, "udp")) == 0) {
765                 protocol = IPPROTO_UDP;
766         } else {
767                 printerr(0, "WARNING: unrecognized protocol, '%s', requested "
768                          "for connection to server %s for user with uid %d\n",
769                          clp->protocol, clp->servername, uid);
770                 goto out_fail;
771         }
772
773         switch (addr->sa_family) {
774         case AF_INET:
775                 salen = sizeof(struct sockaddr_in);
776                 break;
777 #ifdef IPV6_SUPPORTED
778         case AF_INET6:
779                 salen = sizeof(struct sockaddr_in6);
780                 break;
781 #endif /* IPV6_SUPPORTED */
782         default:
783                 printerr(1, "ERROR: Unknown address family %d\n",
784                          addr->sa_family);
785                 goto out_fail;
786         }
787
788         if (!populate_port(addr, salen, clp->prog, clp->vers, protocol))
789                 goto out_fail;
790
791         rpc_clnt = nfs_get_rpcclient(addr, salen, protocol, clp->prog,
792                                      clp->vers, &timeout);
793         if (!rpc_clnt) {
794                 snprintf(rpc_errmsg, sizeof(rpc_errmsg),
795                          "WARNING: can't create %s rpc_clnt to server %s for "
796                          "user with uid %d",
797                          protocol == IPPROTO_TCP ? "tcp" : "udp",
798                          clp->servername, uid);
799                 printerr(0, "%s\n",
800                          clnt_spcreateerror(rpc_errmsg));
801                 goto out_fail;
802         }
803
804         printerr(2, "creating context with server %s\n", clp->servicename);
805         auth = authgss_create_default(rpc_clnt, clp->servicename, &sec);
806         if (!auth) {
807                 /* Our caller should print appropriate message */
808                 printerr(2, "WARNING: Failed to create %s context for "
809                             "user with uid %d for server %s\n",
810                         (authtype == AUTHTYPE_KRB5 ? "krb5":"spkm3"),
811                          uid, clp->servername);
812                 goto out_fail;
813         }
814
815         /* Success !!! */
816         rpc_clnt->cl_auth = auth;
817         *clnt_return = rpc_clnt;
818         *auth_return = auth;
819         retval = 0;
820
821   out:
822         if (sec.cred != GSS_C_NO_CREDENTIAL)
823                 gss_release_cred(&min_stat, &sec.cred);
824         /* Restore euid to original value */
825         if ((save_uid != -1) && (setfsuid(save_uid) != uid)) {
826                 printerr(0, "WARNING: Failed to restore fsuid"
827                             " to uid %d from %d\n", save_uid, uid);
828         }
829         return retval;
830
831   out_fail:
832         /* Only destroy here if failure.  Otherwise, caller is responsible */
833         if (rpc_clnt) clnt_destroy(rpc_clnt);
834
835         goto out;
836 }
837
838
839 /*
840  * this code uses the userland rpcsec gss library to create a krb5
841  * context on behalf of the kernel
842  */
843 void
844 handle_krb5_upcall(struct clnt_info *clp)
845 {
846         uid_t                   uid;
847         CLIENT                  *rpc_clnt = NULL;
848         AUTH                    *auth = NULL;
849         struct authgss_private_data pd;
850         gss_buffer_desc         token;
851         char                    **credlist = NULL;
852         char                    **ccname;
853         char                    **dirname;
854         int                     create_resp = -1;
855
856         printerr(1, "handling krb5 upcall\n");
857
858         token.length = 0;
859         token.value = NULL;
860         memset(&pd, 0, sizeof(struct authgss_private_data));
861
862         if (read(clp->krb5_fd, &uid, sizeof(uid)) < sizeof(uid)) {
863                 printerr(0, "WARNING: failed reading uid from krb5 "
864                             "upcall pipe: %s\n", strerror(errno));
865                 goto out;
866         }
867
868         if (uid != 0 || (uid == 0 && root_uses_machine_creds == 0)) {
869                 /* Tell krb5 gss which credentials cache to use */
870                 for (dirname = ccachesearch; *dirname != NULL; dirname++) {
871                         if (gssd_setup_krb5_user_gss_ccache(uid, clp->servername, *dirname) == 0)
872                                 create_resp = create_auth_rpc_client(clp, &rpc_clnt, &auth, uid,
873                                                              AUTHTYPE_KRB5);
874                         if (create_resp == 0)
875                                 break;
876                 }
877         }
878         if (create_resp != 0) {
879                 if (uid == 0 && root_uses_machine_creds == 1) {
880                         int nocache = 0;
881                         int success = 0;
882                         do {
883                                 gssd_refresh_krb5_machine_credential(clp->servername,
884                                                                      NULL, nocache);
885                                 /*
886                                  * Get a list of credential cache names and try each
887                                  * of them until one works or we've tried them all
888                                  */
889                                 if (gssd_get_krb5_machine_cred_list(&credlist)) {
890                                         printerr(0, "ERROR: No credentials found "
891                                                  "for connection to server %s\n",
892                                                  clp->servername);
893                                                 goto out_return_error;
894                                 }
895                                 for (ccname = credlist; ccname && *ccname; ccname++) {
896                                         gssd_setup_krb5_machine_gss_ccache(*ccname);
897                                         if ((create_auth_rpc_client(clp, &rpc_clnt,
898                                                                     &auth, uid,
899                                                                     AUTHTYPE_KRB5)) == 0) {
900                                                 /* Success! */
901                                                 success++;
902                                                 break;
903                                         } 
904                                         printerr(2, "WARNING: Failed to create machine krb5 context "
905                                                  "with credentials cache %s for server %s\n",
906                                                  *ccname, clp->servername);
907                                 }
908                                 gssd_free_krb5_machine_cred_list(credlist);                     
909                                 if (!success) {
910                                         if(nocache == 0) {
911                                                 nocache++;
912                                                 printerr(2, "WARNING: Machine cache is prematurely expired or corrupted "
913                                                             "trying to recreate cache for server %s\n", clp->servername);
914                                         } else {
915                                                 printerr(1, "WARNING: Failed to create machine krb5 context "
916                                                  "with any credentials cache for server %s\n",
917                                                  clp->servername);
918                                                 goto out_return_error;
919                                         }
920                                 }
921                         } while(!success);
922                 } else {
923                         printerr(1, "WARNING: Failed to create krb5 context "
924                                  "for user with uid %d for server %s\n",
925                                  uid, clp->servername);
926                         goto out_return_error;
927                 }
928         }
929
930         if (!authgss_get_private_data(auth, &pd)) {
931                 printerr(1, "WARNING: Failed to obtain authentication "
932                             "data for user with uid %d for server %s\n",
933                          uid, clp->servername);
934                 goto out_return_error;
935         }
936
937         if (serialize_context_for_kernel(pd.pd_ctx, &token, &krb5oid, NULL)) {
938                 printerr(0, "WARNING: Failed to serialize krb5 context for "
939                             "user with uid %d for server %s\n",
940                          uid, clp->servername);
941                 goto out_return_error;
942         }
943
944         do_downcall(clp->krb5_fd, uid, &pd, &token);
945
946 out:
947         if (token.value)
948                 free(token.value);
949 #ifndef HAVE_LIBTIRPC
950         if (pd.pd_ctx_hndl.length != 0)
951                 authgss_free_private_data(&pd);
952 #endif
953         if (auth)
954                 AUTH_DESTROY(auth);
955         if (rpc_clnt)
956                 clnt_destroy(rpc_clnt);
957         return;
958
959 out_return_error:
960         do_error_downcall(clp->krb5_fd, uid, -1);
961         goto out;
962 }
963
964 /*
965  * this code uses the userland rpcsec gss library to create an spkm3
966  * context on behalf of the kernel
967  */
968 void
969 handle_spkm3_upcall(struct clnt_info *clp)
970 {
971         uid_t                   uid;
972         CLIENT                  *rpc_clnt = NULL;
973         AUTH                    *auth = NULL;
974         struct authgss_private_data pd;
975         gss_buffer_desc         token;
976
977         printerr(2, "handling spkm3 upcall\n");
978
979         token.length = 0;
980         token.value = NULL;
981
982         if (read(clp->spkm3_fd, &uid, sizeof(uid)) < sizeof(uid)) {
983                 printerr(0, "WARNING: failed reading uid from spkm3 "
984                          "upcall pipe: %s\n", strerror(errno));
985                 goto out;
986         }
987
988         if (create_auth_rpc_client(clp, &rpc_clnt, &auth, uid, AUTHTYPE_SPKM3)) {
989                 printerr(0, "WARNING: Failed to create spkm3 context for "
990                             "user with uid %d\n", uid);
991                 goto out_return_error;
992         }
993
994         if (!authgss_get_private_data(auth, &pd)) {
995                 printerr(0, "WARNING: Failed to obtain authentication "
996                             "data for user with uid %d for server %s\n",
997                          uid, clp->servername);
998                 goto out_return_error;
999         }
1000
1001         if (serialize_context_for_kernel(pd.pd_ctx, &token, &spkm3oid, NULL)) {
1002                 printerr(0, "WARNING: Failed to serialize spkm3 context for "
1003                             "user with uid %d for server\n",
1004                          uid, clp->servername);
1005                 goto out_return_error;
1006         }
1007
1008         do_downcall(clp->spkm3_fd, uid, &pd, &token);
1009
1010 out:
1011         if (token.value)
1012                 free(token.value);
1013         if (auth)
1014                 AUTH_DESTROY(auth);
1015         if (rpc_clnt)
1016                 clnt_destroy(rpc_clnt);
1017         return;
1018
1019 out_return_error:
1020         do_error_downcall(clp->spkm3_fd, uid, -1);
1021         goto out;
1022 }