svcgssd: use the actual context expiration for cache
[nfs-utils.git] / utils / gssd / svcgssd_proc.c
1 /*
2   svc_in_gssd_proc.c
3
4   Copyright (c) 2000 The Regents of the University of Michigan.
5   All rights reserved.
6
7   Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8
9   Redistribution and use in source and binary forms, with or without
10   modification, are permitted provided that the following conditions
11   are met:
12
13   1. Redistributions of source code must retain the above copyright
14      notice, this list of conditions and the following disclaimer.
15   2. Redistributions in binary form must reproduce the above copyright
16      notice, this list of conditions and the following disclaimer in the
17      documentation and/or other materials provided with the distribution.
18   3. Neither the name of the University nor the names of its
19      contributors may be used to endorse or promote products derived
20      from this software without specific prior written permission.
21
22   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
23   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25   DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
26   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29   BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34 */
35
36 #include <sys/param.h>
37 #include <sys/stat.h>
38 #include <rpc/rpc.h>
39
40 #include <pwd.h>
41 #include <stdio.h>
42 #include <unistd.h>
43 #include <ctype.h>
44 #include <string.h>
45 #include <fcntl.h>
46 #include <errno.h>
47 #include <nfsidmap.h>
48 #include <nfslib.h>
49 #include <time.h>
50
51 #include "svcgssd.h"
52 #include "gss_util.h"
53 #include "err_util.h"
54 #include "context.h"
55
56 extern char * mech2file(gss_OID mech);
57 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.rpcsec.context/channel"
58 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.rpcsec.init/channel"
59
60 #define TOKEN_BUF_SIZE          8192
61
62 struct svc_cred {
63         uid_t   cr_uid;
64         gid_t   cr_gid;
65         int     cr_ngroups;
66         gid_t   cr_groups[NGROUPS];
67 };
68
69 static int
70 do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
71                 gss_OID mech, gss_buffer_desc *context_token,
72                 int32_t endtime)
73 {
74         FILE *f;
75         int i;
76         char *fname = NULL;
77         int err;
78
79         printerr(1, "doing downcall\n");
80         if ((fname = mech2file(mech)) == NULL)
81                 goto out_err;
82         f = fopen(SVCGSSD_CONTEXT_CHANNEL, "w");
83         if (f == NULL) {
84                 printerr(0, "WARNING: unable to open downcall channel "
85                              "%s: %s\n",
86                              SVCGSSD_CONTEXT_CHANNEL, strerror(errno));
87                 goto out_err;
88         }
89         qword_printhex(f, out_handle->value, out_handle->length);
90         /* XXX are types OK for the rest of this? */
91         /* For context cache, use the actual context endtime */
92         qword_printint(f, endtime);
93         qword_printint(f, cred->cr_uid);
94         qword_printint(f, cred->cr_gid);
95         qword_printint(f, cred->cr_ngroups);
96         printerr(2, "mech: %s, hndl len: %d, ctx len %d, timeout: %d (%d from now), "
97                  "uid: %d, gid: %d, num aux grps: %d:\n",
98                  fname, out_handle->length, context_token->length,
99                  endtime, endtime - time(0),
100                  cred->cr_uid, cred->cr_gid, cred->cr_ngroups);
101         for (i=0; i < cred->cr_ngroups; i++) {
102                 qword_printint(f, cred->cr_groups[i]);
103                 printerr(2, "  (%4d) %d\n", i+1, cred->cr_groups[i]);
104         }
105         qword_print(f, fname);
106         qword_printhex(f, context_token->value, context_token->length);
107         err = qword_eol(f);
108         fclose(f);
109         return err;
110 out_err:
111         printerr(0, "WARNING: downcall failed\n");
112         return -1;
113 }
114
115 struct gss_verifier {
116         u_int32_t       flav;
117         gss_buffer_desc body;
118 };
119
120 #define RPCSEC_GSS_SEQ_WIN      5
121
122 static int
123 send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
124               u_int32_t maj_stat, u_int32_t min_stat,
125               gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
126 {
127         char buf[2 * TOKEN_BUF_SIZE];
128         char *bp = buf;
129         int blen = sizeof(buf);
130         /* XXXARG: */
131         int g;
132
133         printerr(1, "sending null reply\n");
134
135         qword_addhex(&bp, &blen, in_handle->value, in_handle->length);
136         qword_addhex(&bp, &blen, in_token->value, in_token->length);
137         /* For init cache, only needed for a short time */
138         qword_addint(&bp, &blen, time(0) + 60);
139         qword_adduint(&bp, &blen, maj_stat);
140         qword_adduint(&bp, &blen, min_stat);
141         qword_addhex(&bp, &blen, out_handle->value, out_handle->length);
142         qword_addhex(&bp, &blen, out_token->value, out_token->length);
143         qword_addeol(&bp, &blen);
144         if (blen <= 0) {
145                 printerr(0, "WARNING: send_respsonse: message too long\n");
146                 return -1;
147         }
148         g = open(SVCGSSD_INIT_CHANNEL, O_WRONLY);
149         if (g == -1) {
150                 printerr(0, "WARNING: open %s failed: %s\n",
151                                 SVCGSSD_INIT_CHANNEL, strerror(errno));
152                 return -1;
153         }
154         *bp = '\0';
155         printerr(3, "writing message: %s", buf);
156         if (write(g, buf, bp - buf) == -1) {
157                 printerr(0, "WARNING: failed to write message\n");
158                 close(g);
159                 return -1;
160         }
161         close(g);
162         return 0;
163 }
164
/*
 * RPC authentication status values (auth_stat). The last two are the
 * RPCSEC_GSS-specific codes. Kept as macros so they stay usable in #if.
 */
#define rpc_auth_ok			0	/* success */
#define rpc_autherr_badcred		1
#define rpc_autherr_rejectedcred	2
#define rpc_autherr_badverf		3
#define rpc_autherr_rejectedverf	4
#define rpc_autherr_tooweak		5
#define rpcsec_gsserr_credproblem	13	/* RPCSEC_GSS credential problem */
#define rpcsec_gsserr_ctxproblem	14	/* RPCSEC_GSS context problem */
173
174 static void
175 add_supplementary_groups(char *secname, char *name, struct svc_cred *cred)
176 {
177         int ret;
178         static gid_t *groups = NULL;
179
180         cred->cr_ngroups = NGROUPS;
181         ret = nfs4_gss_princ_to_grouplist(secname, name,
182                         cred->cr_groups, &cred->cr_ngroups);
183         if (ret < 0) {
184                 groups = realloc(groups, cred->cr_ngroups*sizeof(gid_t));
185                 ret = nfs4_gss_princ_to_grouplist(secname, name,
186                                 groups, &cred->cr_ngroups);
187                 if (ret < 0)
188                         cred->cr_ngroups = 0;
189                 else {
190                         if (cred->cr_ngroups > NGROUPS)
191                                 cred->cr_ngroups = NGROUPS;
192                         memcpy(cred->cr_groups, groups,
193                                         cred->cr_ngroups*sizeof(gid_t));
194                 }
195         }
196 }
197
198 static int
199 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
200 {
201         u_int32_t       maj_stat, min_stat;
202         gss_buffer_desc name;
203         char            *sname;
204         int             res = -1;
205         uid_t           uid, gid;
206         gss_OID         name_type = GSS_C_NO_OID;
207         char            *secname;
208
209         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
210         if (maj_stat != GSS_S_COMPLETE) {
211                 pgsserr("get_ids: gss_display_name",
212                         maj_stat, min_stat, mech);
213                 goto out;
214         }
215         if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
216             !(sname = calloc(name.length + 1, 1))) {
217                 printerr(0, "WARNING: get_ids: error allocating %d bytes "
218                         "for sname\n", name.length + 1);
219                 gss_release_buffer(&min_stat, &name);
220                 goto out;
221         }
222         memcpy(sname, name.value, name.length);
223         printerr(1, "sname = %s\n", sname);
224         gss_release_buffer(&min_stat, &name);
225
226         res = -EINVAL;
227         if ((secname = mech2file(mech)) == NULL) {
228                 printerr(0, "WARNING: get_ids: error mapping mech to "
229                         "file for name '%s'\n", sname);
230                 goto out_free;
231         }
232         nfs4_init_name_mapping(NULL); /* XXX: should only do this once */
233         res = nfs4_gss_princ_to_ids(secname, sname, &uid, &gid);
234         if (res < 0) {
235                 /*
236                  * -ENOENT means there was no mapping, any other error
237                  * value means there was an error trying to do the
238                  * mapping.
239                  * If there was no mapping, we send down the value -1
240                  * to indicate that the anonuid/anongid for the export
241                  * should be used.
242                  */
243                 if (res == -ENOENT) {
244                         cred->cr_uid = -1;
245                         cred->cr_gid = -1;
246                         cred->cr_ngroups = 0;
247                         res = 0;
248                         goto out_free;
249                 }
250                 printerr(0, "WARNING: get_ids: failed to map name '%s' "
251                         "to uid/gid: %s\n", sname, strerror(-res));
252                 goto out_free;
253         }
254         cred->cr_uid = uid;
255         cred->cr_gid = gid;
256         add_supplementary_groups(secname, sname, cred);
257         res = 0;
258 out_free:
259         free(sname);
260 out:
261         return res;
262 }
263
#ifdef DEBUG
/*
 * Debug helper: print a header line with @description and @length.
 *
 * Then dump @length bytes of @cp to stdout, 16 bytes per row. Each row
 * has three parts:
 *   - the offset (4 hex digits)
 *   - the hex bytes, in pairs separated by a space
 *   - an ASCII column, with non-printable bytes shown as '.'
 * Short final rows are padded so the ASCII column lines up.
 */
void
print_hexl(const char *description, unsigned char *cp, int length)
{
	int offset;

	printf("%s (length %d)\n", description, length);

	for (offset = 0; offset < length; offset += 16) {
		int left = length - offset;
		int count = (left < 16) ? left : 16;
		int k;

		printf("  %04x: ", (u_int)offset);

		/* Hex column: a space follows every odd-indexed byte. */
		for (k = 0; k < 16; k++) {
			const char *gap = (k & 1) ? " " : "";

			if (k < count)
				printf("%02x%s", (u_int)cp[offset + k], gap);
			else
				printf("  %s", gap);
		}
		putchar(' ');

		/* ASCII column. */
		for (k = 0; k < count; k++) {
			unsigned char ch = cp[offset + k];

			putchar(isprint(ch) ? ch : '.');
		}
		putchar('\n');
	}
}
#endif
301
302 void
303 handle_nullreq(FILE *f) {
304         /* XXX initialize to a random integer to reduce chances of unnecessary
305          * invalidation of existing ctx's on restarting svcgssd. */
306         static u_int32_t        handle_seq = 0;
307         char                    in_tok_buf[TOKEN_BUF_SIZE];
308         char                    in_handle_buf[15];
309         char                    out_handle_buf[15];
310         gss_buffer_desc         in_tok = {.value = in_tok_buf},
311                                 out_tok = {.value = NULL},
312                                 in_handle = {.value = in_handle_buf},
313                                 out_handle = {.value = out_handle_buf},
314                                 ctx_token = {.value = NULL},
315                                 ignore_out_tok = {.value = NULL},
316         /* XXX isn't there a define for this?: */
317                                 null_token = {.value = NULL};
318         u_int32_t               ret_flags;
319         gss_ctx_id_t            ctx = GSS_C_NO_CONTEXT;
320         gss_name_t              client_name;
321         gss_OID                 mech = GSS_C_NO_OID;
322         u_int32_t               maj_stat = GSS_S_FAILURE, min_stat = 0;
323         u_int32_t               ignore_min_stat;
324         struct svc_cred         cred;
325         static char             *lbuf = NULL;
326         static int              lbuflen = 0;
327         static char             *cp;
328         int32_t                 ctx_endtime;
329
330         printerr(1, "handling null request\n");
331
332         if (readline(fileno(f), &lbuf, &lbuflen) != 1) {
333                 printerr(0, "WARNING: handle_nullreq: "
334                             "failed reading request\n");
335                 return;
336         }
337
338         cp = lbuf;
339
340         in_handle.length = (size_t) qword_get(&cp, in_handle.value,
341                                               sizeof(in_handle_buf));
342 #ifdef DEBUG
343         print_hexl("in_handle", in_handle.value, in_handle.length);
344 #endif
345
346         in_tok.length = (size_t) qword_get(&cp, in_tok.value,
347                                            sizeof(in_tok_buf));
348 #ifdef DEBUG
349         print_hexl("in_tok", in_tok.value, in_tok.length);
350 #endif
351
352         if (in_tok.length < 0) {
353                 printerr(0, "WARNING: handle_nullreq: "
354                             "failed parsing request\n");
355                 goto out_err;
356         }
357
358         if (in_handle.length != 0) { /* CONTINUE_INIT case */
359                 if (in_handle.length != sizeof(ctx)) {
360                         printerr(0, "WARNING: handle_nullreq: "
361                                     "input handle has unexpected length %d\n",
362                                     in_handle.length);
363                         goto out_err;
364                 }
365                 /* in_handle is the context id stored in the out_handle
366                  * for the GSS_S_CONTINUE_NEEDED case below.  */
367                 memcpy(&ctx, in_handle.value, in_handle.length);
368         }
369
370         maj_stat = gss_accept_sec_context(&min_stat, &ctx, gssd_creds,
371                         &in_tok, GSS_C_NO_CHANNEL_BINDINGS, &client_name,
372                         &mech, &out_tok, &ret_flags, NULL, NULL);
373
374         if (maj_stat == GSS_S_CONTINUE_NEEDED) {
375                 printerr(1, "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
376
377                 /* Save the context handle for future calls */
378                 out_handle.length = sizeof(ctx);
379                 memcpy(out_handle.value, &ctx, sizeof(ctx));
380                 goto continue_needed;
381         }
382         else if (maj_stat != GSS_S_COMPLETE) {
383                 printerr(0, "WARNING: gss_accept_sec_context failed\n");
384                 pgsserr("handle_nullreq: gss_accept_sec_context",
385                         maj_stat, min_stat, mech);
386                 goto out_err;
387         }
388         if (get_ids(client_name, mech, &cred)) {
389                 /* get_ids() prints error msg */
390                 maj_stat = GSS_S_BAD_NAME; /* XXX ? */
391                 gss_release_name(&ignore_min_stat, &client_name);
392                 goto out_err;
393         }
394         gss_release_name(&ignore_min_stat, &client_name);
395
396
397         /* Context complete. Pass handle_seq in out_handle to use
398          * for context lookup in the kernel. */
399         handle_seq++;
400         out_handle.length = sizeof(handle_seq);
401         memcpy(out_handle.value, &handle_seq, sizeof(handle_seq));
402
403         /* kernel needs ctx to calculate verifier on null response, so
404          * must give it context before doing null call: */
405         if (serialize_context_for_kernel(ctx, &ctx_token, mech, &ctx_endtime)) {
406                 printerr(0, "WARNING: handle_nullreq: "
407                             "serialize_context_for_kernel failed\n");
408                 maj_stat = GSS_S_FAILURE;
409                 goto out_err;
410         }
411         /* We no longer need the gss context */
412         gss_delete_sec_context(&ignore_min_stat, &ctx, &ignore_out_tok);
413
414         do_svc_downcall(&out_handle, &cred, mech, &ctx_token, ctx_endtime);
415 continue_needed:
416         send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
417                         &out_handle, &out_tok);
418 out:
419         if (ctx_token.value != NULL)
420                 free(ctx_token.value);
421         if (out_tok.value != NULL)
422                 gss_release_buffer(&ignore_min_stat, &out_tok);
423         printerr(1, "finished handling null request\n");
424         return;
425
426 out_err:
427         if (ctx != GSS_C_NO_CONTEXT)
428                 gss_delete_sec_context(&ignore_min_stat, &ctx, &ignore_out_tok);
429         send_response(f, &in_handle, &in_tok, maj_stat, min_stat,
430                         &null_token, &null_token);
431         goto out;
432 }