more rx/tests cleanups
[openafs.git] / src / rx / rx_xmit_nt.c
1 /*
2  * Copyright 2000, International Business Machines Corporation and others.
3  * All Rights Reserved.
4  *
5  * This software has been released under the terms of the IBM Public
6  * License.  For details, see the LICENSE file in the top-level source
7  * directory or online at http://www.openafs.org/dl/license10.html
8  */
9
10 /* NT does not have uio structs, so we roll our own sendmsg and recvmsg.
11  *
12  * The dangerous part of this code is that it assumes that iovecs 0 and 1
13  * are contiguous and that all of 0 is used before any of 1.
14  * This is true if rx_packets are being sent, so we should be ok.
15  */
16
17 #include <afsconfig.h>
18 #include <afs/param.h>
19
20 #if defined(AFS_NT40_ENV)
21 # include <roken.h>
22 # include <winsock2.h>
23 # if (_WIN32_WINNT < 0x0501)
24 #  undef _WIN32_WINNT
25 #  define _WIN32_WINNT 0x0501
26 # endif
27 # include <mswsock.h>
28
29 # if (_WIN32_WINNT < 0x0600)
30 /*
31  * WSASendMsg -- send data to a specific destination, with options, using
32  *    overlapped I/O where applicable.
33  *
34  * Valid flags for dwFlags parameter:
35  *    MSG_DONTROUTE
36  *    MSG_PARTIAL (a.k.a. MSG_EOR) (only for non-stream sockets)
37  *    MSG_OOB (only for stream style sockets) (NYI)
38  *
39  * Caller must provide either lpOverlapped or lpCompletionRoutine
40  * or neither (both NULL).
41  */
42 typedef
43 INT
44 (PASCAL FAR * LPFN_WSASENDMSG) (
45     IN SOCKET s,
46     IN LPWSAMSG lpMsg,
47     IN DWORD dwFlags,
48     __out_opt LPDWORD lpNumberOfBytesSent,
49     IN LPWSAOVERLAPPED lpOverlapped OPTIONAL,
50     IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine OPTIONAL
51     );
52
53 # define WSAID_WSASENDMSG /* a441e712-754f-43ca-84a7-0dee44cf606d */ \
54     {0xa441e712,0x754f,0x43ca,{0x84,0xa7,0x0d,0xee,0x44,0xcf,0x60,0x6d}}
55 #endif /* AFS_NT40_ENV */
56
57 #include "rx.h"
58 #include "rx_packet.h"
59 #include "rx_globals.h"
60 #include "rx_xmit_nt.h"
61 #include <malloc.h>
62 #include <errno.h>
63
64
65 /*
66  * WSASendMsg is only supported on Vista and above
67  * Neither function is part of the public WinSock API
68  * and therefore the function pointers must be
69  * obtained via WSAIoctl()
70  */
71 static LPFN_WSARECVMSG pWSARecvMsg = NULL;
72 static LPFN_WSASENDMSG pWSASendMsg = NULL;
73
74 void
75 rxi_xmit_init(osi_socket s)
76 {
77     int rc;
78     GUID WSARecvMsg_GUID = WSAID_WSARECVMSG;
79     GUID WSASendMsg_GUID = WSAID_WSASENDMSG;
80     DWORD dwIn, dwOut, NumberOfBytes;
81
82     rc = WSAIoctl( s, SIO_GET_EXTENSION_FUNCTION_POINTER,
83                    &WSARecvMsg_GUID, sizeof(WSARecvMsg_GUID),
84                    &pWSARecvMsg, sizeof(pWSARecvMsg),
85                    &NumberOfBytes, NULL, NULL);
86
87     rc = WSAIoctl( s, SIO_GET_EXTENSION_FUNCTION_POINTER,
88                    &WSASendMsg_GUID, sizeof(WSASendMsg_GUID),
89                    &pWSASendMsg, sizeof(pWSASendMsg),
90                    &NumberOfBytes, NULL, NULL);
91
92     /* Turn on UDP PORT_UNREACHABLE messages */
93     dwIn = 1;
94     rc = WSAIoctl( s, SIO_UDP_CONNRESET,
95                    &dwIn, sizeof(dwIn),
96                    &dwOut, sizeof(dwOut),
97                    &NumberOfBytes, NULL, NULL);
98
99     /* Turn on UDP CIRCULAR QUEUEING messages */
100     dwIn = 1;
101     rc = WSAIoctl( s, SIO_ENABLE_CIRCULAR_QUEUEING,
102                    &dwIn, sizeof(dwIn),
103                    &dwOut, sizeof(dwOut),
104                    &NumberOfBytes, NULL, NULL);
105 }
106
107 int
108 recvmsg(osi_socket socket, struct msghdr *msgP, int flags)
109 {
110     int code;
111
112     if (pWSARecvMsg) {
113         WSAMSG wsaMsg;
114         DWORD  dwBytes;
115
116         wsaMsg.name = (LPSOCKADDR)(msgP->msg_name);
117         wsaMsg.namelen = (INT)(msgP->msg_namelen);
118
119         wsaMsg.lpBuffers = (LPWSABUF) msgP->msg_iov;
120         wsaMsg.dwBufferCount = msgP->msg_iovlen;
121         wsaMsg.Control.len = 0;
122         wsaMsg.Control.buf = NULL;
123         wsaMsg.dwFlags = flags;
124
125         code = pWSARecvMsg(socket, &wsaMsg, &dwBytes, NULL, NULL);
126         if (code == 0) {
127             /* success - return the number of bytes read */
128             code = (int)dwBytes;
129         } else {
130             /* error - set errno and return -1 */
131             if (code == SOCKET_ERROR)
132                 code = WSAGetLastError();
133             if (code == WSAEWOULDBLOCK || code == WSAECONNRESET)
134                 errno = WSAEWOULDBLOCK;
135             else
136                 errno = EIO;
137             code = -1;
138         }
139     } else {
140         char rbuf[RX_MAX_PACKET_SIZE];
141         int size;
142         int off, i, n;
143         int allocd = 0;
144
145         size = rx_maxJumboRecvSize;
146         code =
147             recvfrom((SOCKET) socket, rbuf, size, flags,
148                       (struct sockaddr *)(msgP->msg_name), &(msgP->msg_namelen));
149
150         if (code > 0) {
151             size = code;
152
153             for (off = i = 0; size > 0 && i < msgP->msg_iovlen; i++) {
154                 if (msgP->msg_iov[i].iov_len) {
155                     if (msgP->msg_iov[i].iov_len < size) {
156                         n = msgP->msg_iov[i].iov_len;
157                     } else {
158                         n = size;
159                     }
160                     memcpy(msgP->msg_iov[i].iov_base, &rbuf[off], n);
161                     off += n;
162                     size -= n;
163                 }
164             }
165
166             /* Accounts for any we didn't copy in to iovecs. */
167             code -= size;
168         } else {
169             if (code == SOCKET_ERROR)
170                 code = WSAGetLastError();
171             if (code == WSAEWOULDBLOCK || code == WSAECONNRESET)
172                 errno = WSAEWOULDBLOCK;
173             else
174                 errno = EIO;
175             code = -1;
176         }
177     }
178
179     return code;
180 }
181
182 int
183 sendmsg(osi_socket socket, struct msghdr *msgP, int flags)
184 {
185     int code;
186
187     if (pWSASendMsg) {
188         WSAMSG wsaMsg;
189         DWORD  dwBytes;
190
191         wsaMsg.name = (LPSOCKADDR)(msgP->msg_name);
192         wsaMsg.namelen = (INT)(msgP->msg_namelen);
193
194         wsaMsg.lpBuffers = (LPWSABUF) msgP->msg_iov;
195         wsaMsg.dwBufferCount = msgP->msg_iovlen;
196         wsaMsg.Control.len = 0;
197         wsaMsg.Control.buf = NULL;
198         wsaMsg.dwFlags = 0;
199
200         code = pWSASendMsg(socket, &wsaMsg, flags, &dwBytes, NULL, NULL);
201         if (code == 0) {
202             /* success - return the number of bytes read */
203             code = (int)dwBytes;
204         } else {
205             /* error - set errno and return -1 */
206             if (code == SOCKET_ERROR)
207                 code = WSAGetLastError();
208             switch (code) {
209             case WSAEINPROGRESS:
210             case WSAENETRESET:
211             case WSAENOBUFS:
212                 errno = 0;
213                 break;
214             case WSAEWOULDBLOCK:
215             case WSAECONNRESET:
216                 errno = WSAEWOULDBLOCK;
217                 break;
218             case WSAEHOSTUNREACH:
219                 errno = WSAEHOSTUNREACH;
220                 break;
221             default:
222                 errno = EIO;
223                 break;
224             }
225             code = -1;
226         }
227     } else {
228         char buf[RX_MAX_PACKET_SIZE];
229         char *sbuf = buf;
230         int size, tmp;
231         int off, i, n;
232         int allocd = 0;
233
234         for (size = i = 0; i < msgP->msg_iovlen; i++)
235             size += msgP->msg_iov[i].iov_len;
236
237         if (msgP->msg_iovlen <= 2) {
238             sbuf = msgP->msg_iov[0].iov_base;
239         } else {
240             /* Pack data into array from iovecs */
241             tmp = size;
242             for (off = i = 0; tmp > 0 && i < msgP->msg_iovlen; i++) {
243                 if (msgP->msg_iov[i].iov_len > 0) {
244                     if (tmp > msgP->msg_iov[i].iov_len)
245                         n = msgP->msg_iov[i].iov_len;
246                     else
247                         n = tmp;
248                     memcpy(&sbuf[off], msgP->msg_iov[i].iov_base, n);
249                     off += n;
250                     tmp -= n;
251                 }
252             }
253         }
254
255         code =
256             sendto((SOCKET) socket, sbuf, size, flags,
257                     (struct sockaddr *)(msgP->msg_name), msgP->msg_namelen);
258         if (code == SOCKET_ERROR) {
259             code = WSAGetLastError();
260             switch (code) {
261             case WSAEINPROGRESS:
262             case WSAENETRESET:
263             case WSAENOBUFS:
264                 errno = 0;
265                 break;
266             case WSAEWOULDBLOCK:
267             case WSAECONNRESET:
268                 errno = WSAEWOULDBLOCK;
269                 break;
270             case WSAEHOSTUNREACH:
271                 errno = WSAEHOSTUNREACH;
272                 break;
273             default:
274                 errno = EIO;
275                 break;
276             }
277             code = -1;
278         } else {
279             if (code < size) {
280                 errno = EIO;
281                 code = -1;
282             }
283         }
284     }
285     return code;
286
287 }
288 #endif /* AFS_NT40_ENV */