use-private-xdr-in-kernel-to-avoid-conflicts-over-memory-ownership-20020608
[openafs.git] / src / rxkad / rxkad_common.c
1 /*
2  * Copyright 2000, International Business Machines Corporation and others.
3  * All Rights Reserved.
4  * 
5  * This software has been released under the terms of the IBM Public
6  * License.  For details, see the LICENSE file in the top-level source
7  * directory or online at http://www.openafs.org/dl/license10.html
8  */
9
10 /* The rxkad security object.  Routines used by both client and servers. */
11
12 #include <afsconfig.h>
13 #ifdef KERNEL
14 #include "../afs/param.h"
15 #else
16 #include <afs/param.h>
17 #endif
18
19 RCSID("$Header$");
20
21 #ifdef KERNEL
22 #ifndef UKERNEL
23 #include "../afs/stds.h"
24 #include "../afs/afs_osi.h"
25 #ifdef  AFS_AIX_ENV
26 #include "../h/systm.h"
27 #endif
28 #include "../h/types.h"
29 #include "../h/time.h"
30 #ifndef AFS_LINUX22_ENV
31 #include "../rpc/types.h"
32 #include "../rx/xdr.h"
33 #endif /* AFS_LINUX22_ENV */
34 #else /* !UKERNEL */
35 #include "../afs/sysincludes.h"
36 #include "../afs/afsincludes.h"
37 #endif /* !UKERNEL */
38 #include "../rx/rx.h"
39
40 #else /* KERNEL */
41 #include <afs/stds.h>
42 #include <sys/types.h>
43 #include <time.h>
44 #ifdef AFS_NT40_ENV
45 #include <winsock2.h>
46 #ifdef AFS_PTHREAD_ENV
47 #define RXKAD_STATS_DECLSPEC __declspec(dllexport)
48 #endif
49 #else
50 #include <netinet/in.h>
51 #endif
52 #include <rx/rx.h>
53 #include <rx/xdr.h>
54 #ifdef HAVE_STRING_H
55 #include <string.h>
56 #else
57 #ifdef HAVE_STRINGS_H
58 #include <strings.h>
59 #endif
60 #endif
61
62 #endif /* KERNEL */
63
64 #include "private_data.h"
65 #define XPRT_RXKAD_COMMON
66
67 char *rxi_Alloc();
68
69 #ifndef afs_max
70 #define afs_max(a,b)    ((a) < (b)? (b) : (a))
71 #endif /* afs_max */
72
73 #ifndef KERNEL
74 #define osi_Time() time(0)
75 #endif
76 struct rxkad_stats rxkad_stats = {0};
77
78 /* this call sets up an endpoint structure, leaving it in *network* byte
79  * order so that it can be used quickly for encryption.
80  */
81 rxkad_SetupEndpoint(aconnp, aendpointp)
82   IN struct rx_connection *aconnp;
83   OUT struct rxkad_endpoint *aendpointp;
84 {
85     register afs_int32 i;
86
87     aendpointp->cuid[0] = htonl(aconnp->epoch);
88     i = aconnp->cid & RX_CIDMASK;
89     aendpointp->cuid[1] = htonl(i);
90     aendpointp->cksum = 0;              /* used as cksum only in chal resp. */
91     aendpointp->securityIndex = htonl(aconnp->securityIndex);
92     return 0;
93 }
94
95 /* setup xor information based on session key */
96 rxkad_DeriveXORInfo(aconnp, aschedule, aivec, aresult)
97   IN struct rx_connection *aconnp;
98   IN fc_KeySchedule *aschedule;
99   IN char *aivec;
100   OUT char *aresult;
101 {
102     struct rxkad_endpoint tendpoint;
103     afs_uint32 xor[2];
104
105     rxkad_SetupEndpoint(aconnp, &tendpoint);
106     memcpy((void *)xor, aivec, 2*sizeof(afs_int32));
107     fc_cbc_encrypt(&tendpoint, &tendpoint, sizeof(tendpoint),
108                    aschedule, xor, ENCRYPT);
109     memcpy(aresult, ((char *)&tendpoint) + sizeof(tendpoint) - ENCRYPTIONBLOCKSIZE, ENCRYPTIONBLOCKSIZE);
110     return 0;
111 }
112
113 /* rxkad_CksumChallengeResponse - computes a checksum of the components of a
114  * challenge response packet (which must be unencrypted and in network order).
115  * The endpoint.cksum field is omitted and treated as zero.  The cksum is
116  * returned in network order. */
117
118 afs_uint32 rxkad_CksumChallengeResponse (v2r)
119   IN struct rxkad_v2ChallengeResponse *v2r;
120 {
121     int i;
122     afs_uint32 cksum;
123     u_char *cp = (u_char *)v2r;
124     afs_uint32 savedCksum = v2r->encrypted.endpoint.cksum;
125
126     v2r->encrypted.endpoint.cksum = 0;
127
128     /* this function captured from budb/db_hash.c */
129     cksum = 1000003;
130     for (i=0; i<sizeof(*v2r); i++)
131         cksum = (*cp++) + cksum * 0x10204081;
132
133     v2r->encrypted.endpoint.cksum = savedCksum;
134     return htonl(cksum);
135 }
136
137 void rxkad_SetLevel(conn, level)
138   struct rx_connection *conn;
139   rxkad_level           level;
140 {
141     if (level == rxkad_auth) {
142         rx_SetSecurityHeaderSize (conn, 4);
143         rx_SetSecurityMaxTrailerSize (conn, 4);
144     }
145     else if (level == rxkad_crypt) {
146         rx_SetSecurityHeaderSize (conn, 8);
147         rx_SetSecurityMaxTrailerSize (conn, 8); /* XXX was 7, but why screw with 
148                                                    unaligned accesses? */
149     }
150 }
151
152 /* returns a short integer in host byte order representing a good checksum of
153  * the packet header.
154  */
155 static afs_int32 ComputeSum(apacket, aschedule, aivec)
156 struct rx_packet *apacket;
157 afs_int32 *aivec;
158 fc_KeySchedule *aschedule; {
159     afs_uint32 word[2];
160     register afs_uint32 t;
161
162     t = apacket->header.callNumber;
163     word[0] = htonl(t);
164     /* note that word [1] includes the channel # */
165     t = ((apacket->header.cid & 0x3) << 30)
166             | ((apacket->header.seq & 0x3fffffff));
167     word[1] = htonl(t);
168     /* XOR in the ivec from the per-endpoint encryption */
169     word[0] ^= aivec[0];
170     word[1] ^= aivec[1];
171     /* encrypts word as if it were a character string */
172     fc_ecb_encrypt(word, word, aschedule, ENCRYPT);
173     t = ntohl(word[1]);
174     t = (t >> 16) & 0xffff;
175     if (t == 0) t = 1;  /* so that 0 means don't care */
176     return t;
177 }
178
179
180 static afs_int32 FreeObject (aobj)
181   IN struct rx_securityClass *aobj;
182 {   struct rxkad_cprivate *tcp;         /* both structs start w/ type field */
183
184     if (aobj->refCount > 0) return 0;   /* still in use */
185     tcp = (struct rxkad_cprivate *)aobj->privateData;
186     rxi_Free(aobj, sizeof(struct rx_securityClass));
187     if (tcp->type & rxkad_client) {
188         rxi_Free(tcp, sizeof(struct rxkad_cprivate));
189     }
190     else if (tcp->type & rxkad_server) {
191         rxi_Free(tcp, sizeof(struct rxkad_sprivate));
192     }
193     else { return RXKADINCONSISTENCY; } /* unknown type */
194     LOCK_RXKAD_STATS
195     rxkad_stats.destroyObject++;
196     UNLOCK_RXKAD_STATS
197     return 0;
198 }
199
200 /* rxkad_Close - called by rx with the security class object as a parameter
201  * when a security object is to be discarded */
202
203 rxs_return_t rxkad_Close (aobj)
204   IN struct rx_securityClass *aobj;
205 {
206     afs_int32 code;
207     aobj->refCount--;
208     code = FreeObject (aobj);
209     return code;
210 }
211
212 /* either: called to (re)create a new connection. */
213
214 rxs_return_t rxkad_NewConnection (aobj, aconn)
215   struct rx_securityClass *aobj;
216   struct rx_connection    *aconn;
217 {
218     if (aconn->securityData)
219         return RXKADINCONSISTENCY;      /* already allocated??? */
220
221     if (rx_IsServerConn(aconn)) {
222         int size = sizeof(struct rxkad_sconn);
223         aconn->securityData = (char *) rxi_Alloc (size);
224         memset(aconn->securityData, 0, size); /* initialize it conveniently */
225     }
226     else { /* client */
227         struct rxkad_cprivate *tcp;
228         struct rxkad_cconn *tccp;
229         int size = sizeof(struct rxkad_cconn);
230         tccp = (struct rxkad_cconn *) rxi_Alloc (size);
231         aconn->securityData = (char *) tccp;
232         memset(aconn->securityData, 0, size); /* initialize it conveniently */
233         tcp = (struct rxkad_cprivate *) aobj->privateData;
234         if (!(tcp->type & rxkad_client)) return RXKADINCONSISTENCY;
235         rxkad_SetLevel(aconn, tcp->level); /* set header and trailer sizes */
236         rxkad_AllocCID(aobj, aconn);    /* CHANGES cid AND epoch!!!! */
237         rxkad_DeriveXORInfo(aconn, tcp->keysched, tcp->ivec, tccp->preSeq);
238         LOCK_RXKAD_STATS
239         rxkad_stats.connections[rxkad_LevelIndex(tcp->level)]++;
240         UNLOCK_RXKAD_STATS
241     }
242
243     aobj->refCount++;                   /* attached connection */
244     return 0;
245 }
246
247 /* either: called to destroy a connection. */
248
249 rxs_return_t rxkad_DestroyConnection (aobj, aconn)
250   struct rx_securityClass *aobj;
251   struct rx_connection    *aconn;
252 {
253     if (rx_IsServerConn(aconn)) {
254         struct rxkad_sconn *sconn;
255         struct rxkad_serverinfo *rock;
256         sconn = (struct rxkad_sconn *)aconn->securityData;
257         if (sconn) {
258             aconn->securityData = 0;
259             LOCK_RXKAD_STATS
260             if (sconn->authenticated)
261                 rxkad_stats.destroyConn[rxkad_LevelIndex(sconn->level)]++;
262             else rxkad_stats.destroyUnauth++;
263             UNLOCK_RXKAD_STATS
264             rock = sconn->rock;
265             if (rock) rxi_Free (rock, sizeof(struct rxkad_serverinfo));
266             rxi_Free (sconn, sizeof(struct rxkad_sconn));
267         }
268         else {
269             LOCK_RXKAD_STATS
270             rxkad_stats.destroyUnused++;
271             UNLOCK_RXKAD_STATS
272         }
273     }
274     else {                              /* client */
275         struct rxkad_cconn *cconn;
276         struct rxkad_cprivate *tcp;
277         cconn = (struct rxkad_cconn *)aconn->securityData;
278         tcp = (struct rxkad_cprivate *) aobj->privateData;
279         if (!(tcp->type & rxkad_client)) return RXKADINCONSISTENCY;
280         if (cconn) {
281             aconn->securityData = 0;
282             rxi_Free (cconn, sizeof(struct rxkad_cconn));
283         }
284         LOCK_RXKAD_STATS
285         rxkad_stats.destroyClient++;
286         UNLOCK_RXKAD_STATS
287     }
288     aobj->refCount--;                   /* decrement connection counter */
289     if (aobj->refCount <= 0) {
290         afs_int32 code;
291         code = FreeObject (aobj);
292         if (code) return code;
293     }
294     return 0;
295 }
296
297 /* either: decode packet */
298
299 rxs_return_t rxkad_CheckPacket (aobj, acall, apacket)
300   struct rx_securityClass *aobj;
301   struct rx_call          *acall;
302   struct rx_packet        *apacket;
303 {   struct rx_connection  *tconn;
304     rxkad_level            level;
305     fc_KeySchedule *schedule;
306     fc_InitializationVector *ivec;
307     int len;
308     int nlen;
309     u_int word;                         /* so we get unsigned right-shift */
310     int checkCksum;
311     afs_int32 *preSeq;
312     afs_int32 code;
313
314     tconn = rx_ConnectionOf(acall);
315     len = rx_GetDataSize (apacket);
316     checkCksum = 0;                     /* init */
317     if (rx_IsServerConn(tconn)) {
318         struct rxkad_sconn *sconn;
319         sconn = (struct rxkad_sconn *) tconn->securityData;
320         if (rx_GetPacketCksum(apacket) != 0) sconn->cksumSeen = 1;
321         checkCksum = sconn->cksumSeen;
322         if (sconn && sconn->authenticated &&
323             (osi_Time() < sconn->expirationTime)) {
324             level = sconn->level;
325             LOCK_RXKAD_STATS
326             rxkad_stats.checkPackets[rxkad_StatIndex(rxkad_server, level)]++;
327             UNLOCK_RXKAD_STATS
328             sconn->stats.packetsReceived++;
329             sconn->stats.bytesReceived += len;
330             schedule = (fc_KeySchedule *)sconn->keysched;
331             ivec = (fc_InitializationVector *)sconn->ivec;
332         }
333         else {
334             LOCK_RXKAD_STATS
335             rxkad_stats.expired++;
336             UNLOCK_RXKAD_STATS
337             return RXKADEXPIRED;
338         }
339         preSeq = sconn->preSeq;
340     }
341     else {                              /* client connection */
342         struct rxkad_cconn *cconn;
343         struct rxkad_cprivate *tcp;
344         cconn = (struct rxkad_cconn *) tconn->securityData;
345         if (rx_GetPacketCksum(apacket) != 0) cconn->cksumSeen = 1;
346         checkCksum = cconn->cksumSeen;
347         tcp = (struct rxkad_cprivate *) aobj->privateData;
348         if (!(tcp->type & rxkad_client)) return RXKADINCONSISTENCY;
349         level = tcp->level;
350         LOCK_RXKAD_STATS
351         rxkad_stats.checkPackets[rxkad_StatIndex(rxkad_client, level)]++;
352         UNLOCK_RXKAD_STATS
353         cconn->stats.packetsReceived++;
354         cconn->stats.bytesReceived += len;
355         preSeq = cconn->preSeq;
356         schedule = (fc_KeySchedule *)tcp->keysched;
357         ivec = (fc_InitializationVector *)tcp->ivec;
358     }
359     
360     if (checkCksum) {
361         code = ComputeSum(apacket, schedule, preSeq);
362         if (code != rx_GetPacketCksum(apacket))
363             return RXKADSEALEDINCON;
364     }
365
366     switch (level) {
367       case rxkad_clear: return 0;       /* shouldn't happen */
368       case rxkad_auth:
369         rx_Pullup(apacket, 8);  /* the following encrypts 8 bytes only */
370         fc_ecb_encrypt (rx_DataOf(apacket), rx_DataOf(apacket),
371                         schedule, DECRYPT);
372         break;
373       case rxkad_crypt:
374         code = rxkad_DecryptPacket (tconn, schedule, ivec, len, apacket);
375         if (code) return code;
376         break;
377     }
378     word = ntohl(rx_GetInt32(apacket,0)); /* get first sealed word */
379     if ((word >> 16) !=
380         ((apacket->header.seq ^ apacket->header.callNumber) & 0xffff))
381         return RXKADSEALEDINCON;
382     nlen = word & 0xffff;               /* get real user data length */
383
384     /* The sealed length should be no larger than the initial length, since the  
385      * reverse (round-up) occurs in ...PreparePacket */
386     if (nlen > len)                     
387       return RXKADDATALEN;              
388     rx_SetDataSize (apacket, nlen);
389     return 0;
390 }
391
392 /* either: encode packet */
393
394 rxs_return_t rxkad_PreparePacket (aobj, acall, apacket)
395   struct rx_securityClass *aobj;
396   struct rx_call *acall;
397   struct rx_packet *apacket;
398 {
399     struct rx_connection *tconn;
400     rxkad_level         level;
401     fc_KeySchedule *schedule;
402     fc_InitializationVector *ivec;
403     int len;
404     int nlen;
405     int word;
406     afs_int32 code;
407     afs_int32 *preSeq;
408
409     tconn = rx_ConnectionOf(acall);
410     len = rx_GetDataSize (apacket);
411     if (rx_IsServerConn(tconn)) {
412         struct rxkad_sconn *sconn;
413         sconn = (struct rxkad_sconn *) tconn->securityData;
414         if (sconn && sconn->authenticated &&
415             (osi_Time() < sconn->expirationTime)) {
416             level = sconn->level;
417             LOCK_RXKAD_STATS
418             rxkad_stats.preparePackets[rxkad_StatIndex(rxkad_server, level)]++;
419             UNLOCK_RXKAD_STATS
420             sconn->stats.packetsSent++;
421             sconn->stats.bytesSent += len;
422             schedule = (fc_KeySchedule *)sconn->keysched;
423             ivec = (fc_InitializationVector *)sconn->ivec;
424         }
425         else {
426             LOCK_RXKAD_STATS
427             rxkad_stats.expired++;      /* this is a pretty unlikely path... */
428             UNLOCK_RXKAD_STATS
429             return RXKADEXPIRED;
430         }
431         preSeq = sconn->preSeq;
432     }
433     else {                              /* client connection */
434         struct rxkad_cconn *cconn;
435         struct rxkad_cprivate *tcp;
436         cconn = (struct rxkad_cconn *) tconn->securityData;
437         tcp = (struct rxkad_cprivate *) aobj->privateData;
438         if (!(tcp->type & rxkad_client)) return RXKADINCONSISTENCY;
439         level = tcp->level;
440         LOCK_RXKAD_STATS
441         rxkad_stats.preparePackets[rxkad_StatIndex(rxkad_client, level)]++;
442         UNLOCK_RXKAD_STATS
443         cconn->stats.packetsSent++;
444         cconn->stats.bytesSent += len;
445         preSeq = cconn->preSeq;
446         schedule = (fc_KeySchedule *)tcp->keysched;
447         ivec = (fc_InitializationVector *)tcp->ivec;
448     }
449
450     /* compute upward compatible checksum */
451     rx_SetPacketCksum(apacket, ComputeSum(apacket, schedule, preSeq));
452     if (level == rxkad_clear) return 0;
453
454     len = rx_GetDataSize (apacket);
455     word = (((apacket->header.seq ^ apacket->header.callNumber)
456              & 0xffff) << 16) | (len & 0xffff);
457     rx_PutInt32(apacket,0, htonl(word));   
458
459     switch (level) {
460       case rxkad_clear: return 0;       /* shouldn't happen */
461       case rxkad_auth:
462         nlen = afs_max (ENCRYPTIONBLOCKSIZE,
463                     len + rx_GetSecurityHeaderSize(tconn));
464         if (nlen > (len + rx_GetSecurityHeaderSize(tconn))) {
465           rxi_RoundUpPacket(apacket, nlen - (len + rx_GetSecurityHeaderSize(tconn)));
466         }
467         rx_Pullup(apacket, 8);  /* the following encrypts 8 bytes only */
468         fc_ecb_encrypt (rx_DataOf(apacket), rx_DataOf(apacket),
469                         schedule, ENCRYPT);
470         break;
471       case rxkad_crypt:
472         nlen = round_up_to_ebs(len + rx_GetSecurityHeaderSize(tconn));
473         if (nlen > (len + rx_GetSecurityHeaderSize(tconn))) {
474           rxi_RoundUpPacket(apacket, nlen - (len + rx_GetSecurityHeaderSize(tconn)));
475         }
476         code = rxkad_EncryptPacket (tconn, schedule, ivec, nlen, apacket);
477         if (code) return code;
478         break;
479     }
480     rx_SetDataSize (apacket, nlen);
481     return 0;
482 }
483
484 /* either: return connection stats */
485
486 rxs_return_t rxkad_GetStats (aobj, aconn, astats)
487   IN struct rx_securityClass *aobj;
488   IN struct rx_connection *aconn;
489   OUT struct rx_securityObjectStats *astats;
490 {
491     astats->type = 3;
492     astats->level = ((struct rxkad_cprivate *)aobj->privateData)->level;
493     if (!aconn->securityData) {
494         astats->flags |= 1;
495         return 0;
496     }
497     if (rx_IsServerConn(aconn)) {
498         struct rxkad_sconn *sconn;
499         sconn = (struct rxkad_sconn *) aconn->securityData;
500         astats->level = sconn->level;
501         if (sconn->authenticated) astats->flags |= 2;
502         if (sconn->cksumSeen) astats->flags |= 8;
503         astats->expires = sconn->expirationTime;
504         astats->bytesReceived = sconn->stats.bytesReceived;
505         astats->packetsReceived = sconn->stats.packetsReceived;
506         astats->bytesSent = sconn->stats.bytesSent;
507         astats->packetsSent = sconn->stats.packetsSent;
508     }
509     else { /* client connection */
510         struct rxkad_cconn *cconn;
511         cconn = (struct rxkad_cconn *) aconn->securityData;
512         if (cconn->cksumSeen) astats->flags |= 8;
513         astats->bytesReceived = cconn->stats.bytesReceived;
514         astats->packetsReceived = cconn->stats.packetsReceived;
515         astats->bytesSent = cconn->stats.bytesSent;
516         astats->packetsSent = cconn->stats.packetsSent;
517     }
518     return 0;
519 }