1837302a7db33ec272e3ff57bfaec0eeb6518976
[openafs.git] / src / WINNT / afsd / afsicf.cpp
1 /*
2  * Copyright 2004 by the Massachusetts Institute of Technology
3  *
4  * All rights reserved.
5  *
6  * Permission to use, copy, modify, and distribute this software and its
7  * documentation for any purpose and without fee is hereby granted,
8  * provided that the above copyright notice appear in all copies and that
9  * both that copyright notice and this permission notice appear in
10  * supporting documentation, and that the name of the Massachusetts
11  * Institute of Technology (M.I.T.) not be used in advertising or publicity
12  * pertaining to distribution of the software without specific, written
13  * prior permission.
14  *
15  * M.I.T. DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING
16  * ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL
17  * M.I.T. BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR
18  * ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
19  * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,
20  * ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS
21  * SOFTWARE.
22  *
23  */
24
25 /*
26  * Copyright 2011 by Your File System, Inc.
27  */
28
29 #define _WIN32_DCOM
30 #include <windows.h>
31 #include <netfw.h>
32 #include <objbase.h>
33 #include <oleauto.h>
34 #include "afsicf.h"
35
36 #ifdef TESTMAIN
37 #include<stdio.h>
38 #pragma comment(lib,"ole32.lib")
39 #pragma comment(lib,"oleaut32.lib")
40 #define DEBUGOUT(x) printf(x)
41 #define DEBUGOUTW(x) wprintf(x)
42 #else
43 #define DEBUGOUT(x) OutputDebugString(x)
44 #define DEBUGOUTW(x) OutputDebugStringW(x)
45 #endif
46
47 /* an IPv4, enabled port with global scope */
48 struct global_afs_port_type {
49     LPWSTR      name;
50     LONG        n_port;
51     LPWSTR      str_port;
52     NET_FW_IP_PROTOCOL protocol;
53 };
54
55 typedef struct global_afs_port_type global_afs_port_t;
56
57 global_afs_port_t afs_clientPorts[] = {
58     { L"AFS CacheManager Callback (UDP)", 7001, L"7001", NET_FW_IP_PROTOCOL_UDP }
59 #ifdef AFS_TCP
60 ,   { L"AFS CacheManager Callback (TCP)", 7001, L"7001", NET_FW_IP_PROTOCOL_TCP }
61 #endif
62 };
63
64 global_afs_port_t afs_serverPorts[] = {
65     { L"AFS File Server (UDP)", 7000, L"7000", NET_FW_IP_PROTOCOL_UDP },
66 #ifdef AFS_TCP
67     { L"AFS File Server (TCP)", 7000, L"7000", NET_FW_IP_PROTOCOL_TCP },
68 #endif
69     { L"AFS User & Group Database (UDP)", 7002, L"7002", NET_FW_IP_PROTOCOL_UDP },
70 #ifdef AFS_TCP
71     { L"AFS User & Group Database (TCP)", 7002, L"7002", NET_FW_IP_PROTOCOL_TCP },
72 #endif
73     { L"AFS Volume Location Database (UDP)", 7003, L"7003", NET_FW_IP_PROTOCOL_UDP },
74 #ifdef AFS_TCP
75     { L"AFS Volume Location Database (TCP)", 7003, L"7003", NET_FW_IP_PROTOCOL_TCP },
76 #endif
77     { L"AFS/Kerberos Authentication (UDP)", 7004, L"7004", NET_FW_IP_PROTOCOL_UDP },
78 #ifdef AFS_TCP
79     { L"AFS/Kerberos Authentication (TCP)", 7004, L"7004", NET_FW_IP_PROTOCOL_TCP },
80 #endif
81     { L"AFS Volume Mangement (UDP)", 7005, L"7005", NET_FW_IP_PROTOCOL_UDP },
82 #ifdef AFS_TCP
83     { L"AFS Volume Mangement (TCP)", 7005, L"7005", NET_FW_IP_PROTOCOL_TCP },
84 #endif
85     { L"AFS Error Interpretation (UDP)", 7006, L"7006", NET_FW_IP_PROTOCOL_UDP },
86 #ifdef AFS_TCP
87     { L"AFS Error Interpretation (TCP)", 7006, L"7006", NET_FW_IP_PROTOCOL_TCP },
88 #endif
89     { L"AFS Basic Overseer (UDP)", 7007, L"7007", NET_FW_IP_PROTOCOL_UDP },
90 #ifdef AFS_TCP
91     { L"AFS Basic Overseer (TCP)", 7007, L"7007", NET_FW_IP_PROTOCOL_TCP },
92 #endif
93     { L"AFS Server-to-server Updater (UDP)", 7008, L"7008", NET_FW_IP_PROTOCOL_UDP },
94 #ifdef AFS_TCP
95     { L"AFS Server-to-server Updater (TCP)", 7008, L"7008", NET_FW_IP_PROTOCOL_TCP },
96 #endif
97     { L"AFS Remote Cache Manager (UDP)", 7009, L"7009", NET_FW_IP_PROTOCOL_UDP }
98 #ifdef AFS_TCP
99 ,   { L"AFS Remote Cache Manager (TCP)", 7009, L"7009", NET_FW_IP_PROTOCOL_TCP }
100 #endif
101 };
102
103 HRESULT icf_CheckAndAddPorts2(WCHAR * wServiceName, global_afs_port_t * ports, int nPorts)
104 {
105     INetFwPolicy2 *pNetFwPolicy2 = NULL;
106     INetFwRules *pFwRules = NULL;
107     INetFwRule *pFwRule = NULL;
108     WCHAR wFilename[1024] = L"C:\\Program Files\\OpenAFS\\Client\\Program\\afsd_service.exe";
109
110     long CurrentProfilesBitMask = 0;
111     int  i;
112
113 #ifndef TESTMAIN
114     GetModuleFileNameW(NULL, wFilename, 1024);
115 #endif
116
117     BSTR bstrRuleGroup = SysAllocString(L"OpenAFS Firewall Rules");
118     BSTR bstrRuleApplication = SysAllocString(wFilename);
119     BSTR bstrRuleService = SysAllocString(wServiceName);
120     BSTR bstrInterfaceTypes = SysAllocString(L"all");
121
122     HRESULT hrComInit = S_OK;
123     HRESULT hr = S_OK;
124
125     // Retrieve INetFwPolicy2
126     hr = CoCreateInstance( __uuidof(NetFwPolicy2),
127                            NULL,
128                            CLSCTX_INPROC_SERVER,
129                            __uuidof(INetFwPolicy2),
130                            (void**)&pNetFwPolicy2);
131     if (FAILED(hr))
132     {
133         DEBUGOUT(("Can't create NetFwPolicy2\n"));
134         goto Cleanup;
135     }
136
137     // Retrieve INetFwRules
138     hr = pNetFwPolicy2->get_Rules(&pFwRules);
139     if (FAILED(hr))
140     {
141         DEBUGOUT(("get_Rules failed\n"));
142         goto Cleanup;
143     }
144
145     if ( nPorts == 0 )
146         DEBUGOUT(("No port specified\n"));
147
148     for ( i=0; i < nPorts; i++)
149     {
150         BSTR bstrRuleName = SysAllocString(ports[i].name);
151         BSTR bstrRuleDescription = SysAllocString(ports[i].name);
152         BSTR bstrRuleLPorts = SysAllocString(ports[i].str_port);
153
154         hr = pFwRules->Item(bstrRuleName, &pFwRule);
155         if (FAILED(hr))
156         {
157             // Create a new Firewall Rule object.
158             hr = CoCreateInstance( __uuidof(NetFwRule),
159                                    NULL,
160                                    CLSCTX_INPROC_SERVER,
161                                    __uuidof(INetFwRule),
162                                    (void**)&pFwRule);
163             if (SUCCEEDED(hr))
164             {
165                 // Populate the Firewall Rule object
166                 pFwRule->put_Name(bstrRuleName);
167                 pFwRule->put_Description(bstrRuleDescription);
168                 pFwRule->put_ApplicationName(bstrRuleApplication);
169
170                 // Add the Firewall Rule
171                 hr = pFwRules->Add(pFwRule);
172                 if (FAILED(hr))
173                 {
174                     DEBUGOUT(("Advanced Firewall Rule Add failed\n"));
175                 }
176                 else
177                 {
178                     DEBUGOUT(("Advanced Firewall Rule Add successful\n"));
179
180                     //
181                     // Do not assign the service name to the rule.
182                     // Only specify the executable name. According to feedback
183                     // in openafs-info, the service name filter blocks the rule.
184                     //
185                     pFwRule->put_ServiceName(NULL);
186                     pFwRule->put_Protocol(ports[i].protocol);
187                     pFwRule->put_LocalPorts(bstrRuleLPorts);
188                     pFwRule->put_Grouping(bstrRuleGroup);
189                     pFwRule->put_Profiles(NET_FW_PROFILE2_ALL);
190                     pFwRule->put_Action(NET_FW_ACTION_ALLOW);
191                     pFwRule->put_Enabled(VARIANT_TRUE);
192                     pFwRule->put_EdgeTraversal(VARIANT_TRUE);
193                     pFwRule->put_InterfaceTypes(bstrInterfaceTypes);
194                 }
195             }
196             else
197             {
198                 DEBUGOUT(("CoCreateInstance INetFwRule failed\n"));
199             }
200         }
201         else
202         {
203             DEBUGOUT(("INetFwRule already exists\n"));
204
205             hr = pFwRule->put_ServiceName(NULL);
206             if (SUCCEEDED(hr))
207             {
208                 DEBUGOUT(("INetFwRule Service Name Updated\n"));
209             }
210
211             hr = pFwRule->put_ApplicationName(bstrRuleApplication);
212             if (SUCCEEDED(hr))
213             {
214                 DEBUGOUT(("INetFwRule Application Name Updated\n"));
215             }
216
217             hr = pFwRule->put_EdgeTraversal(VARIANT_TRUE);
218             if (SUCCEEDED(hr))
219             {
220                 DEBUGOUT(("INetFwRule Edge Traversal Updated\n"));
221             }
222
223             hr = pFwRule->put_InterfaceTypes(bstrInterfaceTypes);
224             if (SUCCEEDED(hr))
225             {
226                 DEBUGOUT(("INetFwRule Interface Types Updated\n"));
227             }
228         }
229
230         SysFreeString(bstrRuleName);
231         SysFreeString(bstrRuleDescription);
232         SysFreeString(bstrRuleLPorts);
233     }
234
235   Cleanup:
236
237     // Free BSTR's
238     SysFreeString(bstrRuleGroup);
239     SysFreeString(bstrRuleApplication);
240     SysFreeString(bstrRuleService);
241     SysFreeString(bstrInterfaceTypes);
242
243     // Release the INetFwRule object
244     if (pFwRule != NULL)
245     {
246         pFwRule->Release();
247     }
248
249     // Release the INetFwRules object
250     if (pFwRules != NULL)
251     {
252         pFwRules->Release();
253     }
254
255     // Release the INetFwPolicy2 object
256     if (pNetFwPolicy2 != NULL)
257     {
258         pNetFwPolicy2->Release();
259     }
260
261     // Uninitialize COM.
262     if (SUCCEEDED(hrComInit))
263     {
264         CoUninitialize();
265     }
266
267     return 0;
268 }
269
270
271 HRESULT icf_OpenFirewallProfile(INetFwProfile ** fwProfile)
272 {
273     HRESULT hr = S_OK;
274     INetFwMgr* fwMgr = NULL;
275     INetFwPolicy* fwPolicy = NULL;
276
277     *fwProfile = NULL;
278
279     // Create an instance of the firewall settings manager.
280     hr = CoCreateInstance(
281             __uuidof(NetFwMgr),
282             NULL,
283             CLSCTX_INPROC_SERVER,
284             __uuidof(INetFwMgr),
285             reinterpret_cast<void**>(static_cast<INetFwMgr**>(&fwMgr))
286             );
287     if (FAILED(hr))
288     {
289         DEBUGOUT(("Can't create fwMgr\n"));
290         goto error;
291     }
292
293     // Retrieve the local firewall policy.
294     hr = fwMgr->get_LocalPolicy(&fwPolicy);
295     if (FAILED(hr))
296     {
297         DEBUGOUT(("Cant get local policy\n"));
298         goto error;
299     }
300
301     // Retrieve the firewall profile currently in effect.
302     hr = fwPolicy->get_CurrentProfile(fwProfile);
303     if (FAILED(hr))
304     {
305         DEBUGOUT(("Can't get current profile\n"));
306         goto error;
307     }
308
309   error:
310
311     // Release the local firewall policy.
312     if (fwPolicy != NULL)
313     {
314         fwPolicy->Release();
315     }
316
317     // Release the firewall settings manager.
318     if (fwMgr != NULL)
319     {
320         fwMgr->Release();
321     }
322
323     return hr;
324 }
325
326 HRESULT icf_CheckAndAddPorts(INetFwProfile * fwProfile, global_afs_port_t * ports, int nPorts) {
327     INetFwOpenPorts * fwPorts = NULL;
328     INetFwOpenPort * fwPort = NULL;
329     HRESULT hr;
330     HRESULT rhr = S_OK; /* return value */
331     int i = 0;
332
333     hr = fwProfile->get_GloballyOpenPorts(&fwPorts);
334     if (FAILED(hr)) {
335         // Abort!
336         DEBUGOUT(("Can't get globallyOpenPorts\n"));
337         rhr = hr;
338         goto cleanup;
339     }
340
341     // go through the supplied ports
342     for (i=0; i<nPorts; i++) {
343         VARIANT_BOOL vbEnabled;
344         BSTR bstName = NULL;
345         BOOL bCreate = FALSE;
346         fwPort = NULL;
347
348         hr = fwPorts->Item(ports[i].n_port, ports[i].protocol, &fwPort);
349         if (SUCCEEDED(hr)) {
350             DEBUGOUTW((L"Found port for %S\n",ports[i].name));
351             hr = fwPort->get_Enabled(&vbEnabled);
352             if (SUCCEEDED(hr)) {
353                 if ( vbEnabled == VARIANT_FALSE ) {
354                     hr = fwPort->put_Enabled(VARIANT_TRUE);
355                     if (FAILED(hr)) {
356                         // failed. Mark as failure. Don't try to create the port either.
357                         rhr = hr;
358                     }
359                 } // else we are fine
360             } else {
361                 // Something is wrong with the port.
362                 // We try to create a new one thus overriding this faulty one.
363                 bCreate = TRUE;
364             }
365             fwPort->Release();
366             fwPort = NULL;
367         } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) {
368             DEBUGOUTW((L"Port not found for %S\n", ports[i].name));
369             bCreate = TRUE;
370         }
371
372         if (bCreate) {
373             DEBUGOUTW((L"Trying to create port %S\n",ports[i].name));
374             hr = CoCreateInstance( __uuidof(NetFwOpenPort),
375                                    NULL,
376                                    CLSCTX_INPROC_SERVER,
377                                    __uuidof(INetFwOpenPort),
378                                    reinterpret_cast<void**>
379                                    (static_cast<INetFwOpenPort**>(&fwPort))
380                                    );
381
382             if (FAILED(hr)) {
383                 DEBUGOUT(("Can't create port\n"));
384                 rhr = hr;
385             } else {
386                 DEBUGOUT(("Created port\n"));
387                 hr = fwPort->put_IpVersion( NET_FW_IP_VERSION_ANY );
388                 if (FAILED(hr)) {
389                     DEBUGOUT(("Can't set IpVersion\n"));
390                     rhr = hr;
391                     goto abandon_port;
392                 }
393
394                 hr = fwPort->put_Port( ports[i].n_port );
395                 if (FAILED(hr)) {
396                     DEBUGOUT(("Can't set Port\n"));
397                     rhr = hr;
398                     goto abandon_port;
399                 }
400
401                 hr = fwPort->put_Protocol( ports[i].protocol );
402                 if (FAILED(hr)) {
403                     DEBUGOUT(("Can't set Protocol\n"));
404                     rhr = hr;
405                     goto abandon_port;
406                 }
407
408                 hr = fwPort->put_Scope( NET_FW_SCOPE_ALL );
409                 if (FAILED(hr)) {
410                     DEBUGOUT(("Can't set Scope\n"));
411                     rhr = hr;
412                     goto abandon_port;
413                 }
414
415                 bstName = SysAllocString( ports[i].name );
416
417                 if (SysStringLen(bstName) == 0) {
418                     rhr = E_OUTOFMEMORY;
419                 } else {
420                     hr = fwPort->put_Name( bstName );
421                     if (FAILED(hr)) {
422                         DEBUGOUT(("Can't set Name\n"));
423                         rhr = hr;
424                         SysFreeString( bstName );
425                         goto abandon_port;
426                     }
427                 }
428
429                 SysFreeString( bstName );
430
431                 hr = fwPorts->Add( fwPort );
432                 if (FAILED(hr)) {
433                     DEBUGOUT(("Can't add port\n"));
434                     rhr = hr;
435                 } else
436                     DEBUGOUT(("Added port\n"));
437
438               abandon_port:
439                 fwPort->Release();
440             }
441         }
442     } // loop through ports
443
444     fwPorts->Release();
445
446   cleanup:
447
448     if (fwPorts != NULL)
449         fwPorts->Release();
450
451     return rhr;
452 }
453
454 long icf_CheckAndAddAFSPorts(int portset) {
455     HRESULT hr;
456     BOOL coInitialized = FALSE;
457     INetFwProfile * fwProfile = NULL;
458     global_afs_port_t * ports;
459     WCHAR * wServiceName;
460     int nports;
461     long code = 0;
462
463     if (portset == AFS_PORTSET_CLIENT) {
464         ports = afs_clientPorts;
465         nports = sizeof(afs_clientPorts) / sizeof(*afs_clientPorts);
466         wServiceName = L"TransarcAFSDaemon";
467     } else if (portset == AFS_PORTSET_SERVER) {
468         ports = afs_serverPorts;
469         nports = sizeof(afs_serverPorts) / sizeof(*afs_serverPorts);
470         wServiceName = L"TransarcAFSServer";
471     } else {
472         DEBUGOUT(("Invalid port set\n"));
473         return 1; /* Invalid port set */
474     }
475     hr = CoInitializeEx( NULL,
476                          COINIT_APARTMENTTHREADED | COINIT_DISABLE_OLE1DDE
477                          );
478
479     if (SUCCEEDED(hr) || RPC_E_CHANGED_MODE == hr)
480     {
481        coInitialized = TRUE;
482     }
483     // not necessarily catastrophic if the call failed.  We'll try to
484     // continue as if it succeeded.
485
486     hr = icf_CheckAndAddPorts2(wServiceName, ports, nports);
487     if (FAILED(hr)) {
488         DEBUGOUT(("INetFwProfile2 failed, trying INetFwProfile\n"));
489         hr = icf_OpenFirewallProfile(&fwProfile);
490         if (FAILED(hr)) {
491             // Ok. That didn't work.  This could be because the machine we
492             // are running on doesn't have Windows Firewall.  We'll return
493             // a failure to the caller, which shouldn't be taken to mean
494             // it's catastrophic.
495             DEBUGOUT(("Can't open Firewall profile\n"));
496             code = 2;
497             goto cleanup;
498         }
499
500         // Now that we have a firewall profile, we can start checking
501         // and adding the ports that we want.
502         hr = icf_CheckAndAddPorts(fwProfile, ports, nports);
503         if (FAILED(hr))
504             code = 3;
505     }
506
507   cleanup:
508     if (coInitialized) {
509         CoUninitialize();
510     }
511
512     return code;
513 }
514
515
516 #ifdef TESTMAIN
517 int main(int argc, char **argv) {
518     printf("Starting...\n");
519     if (icf_CheckAndAddAFSPorts(AFS_PORTSET_CLIENT))
520         printf("Failed\n");
521     else
522         printf("Succeeded\n");
523     printf("Done\n");
524     return 0;
525 }
526 #endif