down-before-busy-20040723
[openafs.git] / src / WINNT / afsd / afsicf.cpp
1 /*
2
3 Copyright 2004 by the Massachusetts Institute of Technology
4
5 All rights reserved.
6
7 Permission to use, copy, modify, and distribute this software and its
8 documentation for any purpose and without fee is hereby granted,
9 provided that the above copyright notice appear in all copies and that
10 both that copyright notice and this permission notice appear in
11 supporting documentation, and that the name of the Massachusetts
12 Institute of Technology (M.I.T.) not be used in advertising or publicity
13 pertaining to distribution of the software without specific, written
14 prior permission.
15
16 M.I.T. DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING
17 ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL
18 M.I.T. BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR
19 ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
20 WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,
21 ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS
22 SOFTWARE.
23
24 */
25
26 #define _WIN32_DCOM
27 #include <windows.h>
28 #include <netfw.h>
29 #include <objbase.h>
30 #include <oleauto.h>
31 #include "afsicf.h"
32
/* Uncomment to build this file as a standalone test program (see main below). */
//#define TESTMAIN

#ifdef TESTMAIN
# include <stdio.h>
# pragma comment(lib, "ole32.lib")
# pragma comment(lib, "oleaut32.lib")
/*
 * DEBUGOUT takes a complete, parenthesized printf argument list, e.g.
 *     DEBUGOUT(("value is %d\n", v));
 * In test builds it forwards to printf; otherwise it expands to nothing,
 * so its arguments are not evaluated.
 */
# define DEBUGOUT(x) printf x
#else
# define DEBUGOUT(x)
#endif
43
44 /* an IPv4, enabled port with global scope */
45 struct global_afs_port_type {
46     LPWSTR      name;
47         LONG    port;
48         NET_FW_IP_PROTOCOL protocol;
49 };
50
51 typedef struct global_afs_port_type global_afs_port_t;
52
53 global_afs_port_t afs_clientPorts[] = {
54         { L"AFS CacheManager Callback (UDP)", 7001, NET_FW_IP_PROTOCOL_UDP },
55         { L"AFS CacheManager Callback (TCP)", 7001, NET_FW_IP_PROTOCOL_TCP }
56 };
57
58 global_afs_port_t afs_serverPorts[] = {
59         { L"AFS File Server (UDP)", 7000, NET_FW_IP_PROTOCOL_UDP },
60         { L"AFS File Server (TCP)", 7000, NET_FW_IP_PROTOCOL_TCP },
61         { L"AFS User & Group Database (UDP)", 7002, NET_FW_IP_PROTOCOL_UDP },
62         { L"AFS User & Group Database (TCP)", 7002, NET_FW_IP_PROTOCOL_TCP },
63         { L"AFS Volume Location Database (UDP)", 7003, NET_FW_IP_PROTOCOL_UDP },
64         { L"AFS Volume Location Database (TCP)", 7003, NET_FW_IP_PROTOCOL_TCP },
65         { L"AFS/Kerberos Authentication (UDP)", 7004, NET_FW_IP_PROTOCOL_UDP },
66         { L"AFS/Kerberos Authentication (TCP)", 7004, NET_FW_IP_PROTOCOL_TCP },
67         { L"AFS Volume Mangement (UDP)", 7005, NET_FW_IP_PROTOCOL_UDP },
68         { L"AFS Volume Mangement (TCP)", 7005, NET_FW_IP_PROTOCOL_TCP },
69         { L"AFS Error Interpretation (UDP)", 7006, NET_FW_IP_PROTOCOL_UDP },
70         { L"AFS Error Interpretation (TCP)", 7006, NET_FW_IP_PROTOCOL_TCP },
71         { L"AFS Basic Overseer (UDP)", 7007, NET_FW_IP_PROTOCOL_UDP },
72         { L"AFS Basic Overseer (TCP)", 7007, NET_FW_IP_PROTOCOL_TCP },
73         { L"AFS Server-to-server Updater (UDP)", 7008, NET_FW_IP_PROTOCOL_UDP },
74         { L"AFS Server-to-server Updater (TCP)", 7008, NET_FW_IP_PROTOCOL_TCP },
75         { L"AFS Remote Cache Manager (UDP)", 7009, NET_FW_IP_PROTOCOL_UDP },
76         { L"AFS Remote Cache Manager (TCP)", 7009, NET_FW_IP_PROTOCOL_TCP }
77 };
78
79 HRESULT icf_OpenFirewallProfile(INetFwProfile ** fwProfile) {
80     HRESULT hr = S_OK;
81     INetFwMgr* fwMgr = NULL;
82     INetFwPolicy* fwPolicy = NULL;
83
84     *fwProfile = NULL;
85
86     // Create an instance of the firewall settings manager.
87     hr = CoCreateInstance(
88             __uuidof(NetFwMgr),
89             NULL,
90             CLSCTX_INPROC_SERVER,
91             __uuidof(INetFwMgr),
92             reinterpret_cast<void**>(static_cast<INetFwMgr**>(&fwMgr))
93             );
94     if (FAILED(hr))
95     {
96                 DEBUGOUT(("Can't create fwMgr\n"));
97         goto error;
98     }
99
100     // Retrieve the local firewall policy.
101     hr = fwMgr->get_LocalPolicy(&fwPolicy);
102     if (FAILED(hr))
103     {
104                 DEBUGOUT(("Cant get local policy\n"));
105         goto error;
106     }
107
108     // Retrieve the firewall profile currently in effect.
109     hr = fwPolicy->get_CurrentProfile(fwProfile);
110     if (FAILED(hr))
111     {
112                 DEBUGOUT(("Can't get current profile\n"));
113         goto error;
114     }
115
116 error:
117
118     // Release the local firewall policy.
119     if (fwPolicy != NULL)
120     {
121         fwPolicy->Release();
122     }
123
124     // Release the firewall settings manager.
125     if (fwMgr != NULL)
126     {
127         fwMgr->Release();
128     }
129
130     return hr;
131 }
132
133 HRESULT icf_CheckAndAddPorts(INetFwProfile * fwProfile, global_afs_port_t * ports, int nPorts) {
134         INetFwOpenPorts * fwPorts = NULL;
135         INetFwOpenPort * fwPort = NULL;
136         HRESULT hr;
137         HRESULT rhr = S_OK; /* return value */
138
139         hr = fwProfile->get_GloballyOpenPorts(&fwPorts);
140         if (FAILED(hr)) {
141                 // Abort!
142                 DEBUGOUT(("Can't get globallyOpenPorts\n"));
143                 rhr = hr;
144                 goto cleanup;
145         }
146
147         // go through the supplied ports
148         for (int i=0; i<nPorts; i++) {
149                 VARIANT_BOOL vbEnabled;
150                 BSTR bstName = NULL;
151                 BOOL bCreate = FALSE;
152                 fwPort = NULL;
153
154                 hr = fwPorts->Item(ports[i].port, ports[i].protocol, &fwPort);
155                 if (SUCCEEDED(hr)) {
156                         DEBUGOUT(("Found port for %S\n",ports[i].name));
157             hr = fwPort->get_Enabled(&vbEnabled);
158                         if (SUCCEEDED(hr)) {
159                                 if ( vbEnabled == VARIANT_FALSE ) {
160                                         hr = fwPort->put_Enabled(VARIANT_TRUE);
161                                         if (FAILED(hr)) {
162                                                 // failed. Mark as failure. Don't try to create the port either.
163                                                 rhr = hr;
164                                         }
165                                 } // else we are fine
166                         } else {
167                 // Something is wrong with the port.
168                                 // We try to create a new one thus overriding this faulty one.
169                                 bCreate = TRUE;
170                         }
171                         fwPort->Release();
172                         fwPort = NULL;
173                 } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) {
174                         DEBUGOUT(("Port not found for %S\n", ports[i].name));
175                         bCreate = TRUE;
176                 }
177
178                 if (bCreate) {
179                         DEBUGOUT(("Trying to create port %S\n",ports[i].name));
180                         hr = CoCreateInstance(
181                                 __uuidof(NetFwOpenPort),
182                                 NULL,
183                                 CLSCTX_INPROC_SERVER,
184                                 __uuidof(INetFwOpenPort),
185                                 reinterpret_cast<void**>
186                                         (static_cast<INetFwOpenPort**>(&fwPort))
187                                 );
188
189                         if (FAILED(hr)) {
190                                 DEBUGOUT(("Can't create port\n"));
191                 rhr = hr;
192                         } else {
193                                 DEBUGOUT(("Created port\n"));
194                                 hr = fwPort->put_IpVersion( NET_FW_IP_VERSION_ANY );
195                                 if (FAILED(hr)) {
196                                         DEBUGOUT(("Can't set IpVersion\n"));
197                                         rhr = hr;
198                                         goto abandon_port;
199                                 }
200
201                                 hr = fwPort->put_Port( ports[i].port );
202                                 if (FAILED(hr)) {
203                                         DEBUGOUT(("Can't set Port\n"));
204                                         rhr = hr;
205                                         goto abandon_port;
206                                 }
207
208                                 hr = fwPort->put_Protocol( ports[i].protocol );
209                                 if (FAILED(hr)) {
210                                         DEBUGOUT(("Can't set Protocol\n"));
211                                         rhr = hr;
212                                         goto abandon_port;
213                                 }
214
215                                 hr = fwPort->put_Scope( NET_FW_SCOPE_ALL );
216                                 if (FAILED(hr)) {
217                                         DEBUGOUT(("Can't set Scope\n"));
218                                         rhr = hr;
219                                         goto abandon_port;
220                                 }
221
222                                 bstName = SysAllocString( ports[i].name );
223
224                                 if (SysStringLen(bstName) == 0) {
225                                         rhr = E_OUTOFMEMORY;
226                                 } else {
227                                         hr = fwPort->put_Name( bstName );
228                                         if (FAILED(hr)) {
229                                                 DEBUGOUT(("Can't set Name\n"));
230                                                 rhr = hr;
231                                                 SysFreeString( bstName );
232                                                 goto abandon_port;
233                                         }
234                                 }
235
236                                 SysFreeString( bstName );
237
238                                 hr = fwPorts->Add( fwPort );
239                                 if (FAILED(hr)) {
240                                         DEBUGOUT(("Can't add port\n"));
241                                         rhr = hr;
242                                 } else
243                                         DEBUGOUT(("Added port\n"));
244
245 abandon_port:
246                                 fwPort->Release();
247                         }
248                 }
249         } // loop through ports
250
251         fwPorts->Release();
252
253 cleanup:
254
255         if (fwPorts != NULL)
256                 fwPorts->Release();
257
258         return rhr;
259 }
260
261 long icf_CheckAndAddAFSPorts(int portset) {
262         HRESULT hr;
263         BOOL coInitialized = FALSE;
264         INetFwProfile * fwProfile = NULL;
265         global_afs_port_t * ports;
266         int nports;
267         long code = 0;
268
269         if (portset == AFS_PORTSET_CLIENT) {
270                 ports = afs_clientPorts;
271                 nports = sizeof(afs_clientPorts) / sizeof(*afs_clientPorts);
272         } else if (portset == AFS_PORTSET_SERVER) {
273                 ports = afs_serverPorts;
274                 nports = sizeof(afs_serverPorts) / sizeof(*afs_serverPorts);
275         } else
276                 return 1; /* Invalid port set */
277
278         hr = CoInitializeEx(
279         NULL,
280         COINIT_APARTMENTTHREADED | COINIT_DISABLE_OLE1DDE
281         );
282
283         if (SUCCEEDED(hr) || RPC_E_CHANGED_MODE == hr)
284     {
285        coInitialized = TRUE;
286     }
287         // not necessarily catastrophic if the call failed.  We'll try to
288         // continue as if it succeeded.
289
290     hr = icf_OpenFirewallProfile(&fwProfile);
291         if (FAILED(hr)) {
292                 // Ok. That didn't work.  This could be because the machine we
293                 // are running on doesn't have Windows Firewall.  We'll return
294                 // a failure to the caller, which shouldn't be taken to mean
295                 // it's catastrophic.
296                 DEBUGOUT(("Can't open Firewall profile\n"));
297                 code = 1;
298                 goto cleanup;
299         }
300
301         // Now that we have a firewall profile, we can start checking
302         // and adding the ports that we want.
303         hr = icf_CheckAndAddPorts(fwProfile, ports, nports);
304         if (FAILED(hr))
305                 code = 1;
306
307 cleanup:
308         if (coInitialized) {
309                 CoUninitialize();
310         }
311
312         return code;
313 }
314
315
#ifdef TESTMAIN
/*
 * Standalone test driver: ensure the client port set is open and print
 * whether icf_CheckAndAddAFSPorts reported success.  Always exits with 0.
 */
int main(int argc, char **argv) {
    long failed;

    printf("Starting...\n");
    failed = icf_CheckAndAddAFSPorts(AFS_PORTSET_CLIENT);
    fputs(failed ? "Failed\n" : "Succeeded\n", stdout);
    printf("Done\n");
    return 0;
}
#endif