windows-misc-20041122
[openafs.git] / src / WINNT / afsd / afsicf.cpp
1 /*
2
3 Copyright 2004 by the Massachusetts Institute of Technology
4
5 All rights reserved.
6
7 Permission to use, copy, modify, and distribute this software and its
8 documentation for any purpose and without fee is hereby granted,
9 provided that the above copyright notice appear in all copies and that
10 both that copyright notice and this permission notice appear in
11 supporting documentation, and that the name of the Massachusetts
12 Institute of Technology (M.I.T.) not be used in advertising or publicity
13 pertaining to distribution of the software without specific, written
14 prior permission.
15
16 M.I.T. DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING
17 ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT SHALL
18 M.I.T. BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR
19 ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
20 WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,
21 ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS
22 SOFTWARE.
23
24 */
25
26 #define _WIN32_DCOM
27 #include <windows.h>
28 #include <netfw.h>
29 #include <objbase.h>
30 #include <oleauto.h>
31 #include "afsicf.h"
32
33 //#define TESTMAIN
34
35 #ifdef TESTMAIN
36 #include<stdio.h>
37 #pragma comment(lib,"ole32.lib")
38 #pragma comment(lib,"oleaut32.lib")
39 #define DEBUGOUT(x) printf x
40 #else
41 #define DEBUGOUT(x)
42 #endif
43
44 /* an IPv4, enabled port with global scope */
45 struct global_afs_port_type {
46     LPWSTR      name;
47         LONG    port;
48         NET_FW_IP_PROTOCOL protocol;
49 };
50
51 typedef struct global_afs_port_type global_afs_port_t;
52
53 global_afs_port_t afs_clientPorts[] = {
54         { L"AFS CacheManager Callback (UDP)", 7001, NET_FW_IP_PROTOCOL_UDP },
55         { L"AFS CacheManager Callback (TCP)", 7001, NET_FW_IP_PROTOCOL_TCP }
56 };
57
58 global_afs_port_t afs_serverPorts[] = {
59         { L"AFS File Server (UDP)", 7000, NET_FW_IP_PROTOCOL_UDP },
60         { L"AFS File Server (TCP)", 7000, NET_FW_IP_PROTOCOL_TCP },
61         { L"AFS User & Group Database (UDP)", 7002, NET_FW_IP_PROTOCOL_UDP },
62         { L"AFS User & Group Database (TCP)", 7002, NET_FW_IP_PROTOCOL_TCP },
63         { L"AFS Volume Location Database (UDP)", 7003, NET_FW_IP_PROTOCOL_UDP },
64         { L"AFS Volume Location Database (TCP)", 7003, NET_FW_IP_PROTOCOL_TCP },
65         { L"AFS/Kerberos Authentication (UDP)", 7004, NET_FW_IP_PROTOCOL_UDP },
66         { L"AFS/Kerberos Authentication (TCP)", 7004, NET_FW_IP_PROTOCOL_TCP },
67         { L"AFS Volume Mangement (UDP)", 7005, NET_FW_IP_PROTOCOL_UDP },
68         { L"AFS Volume Mangement (TCP)", 7005, NET_FW_IP_PROTOCOL_TCP },
69         { L"AFS Error Interpretation (UDP)", 7006, NET_FW_IP_PROTOCOL_UDP },
70         { L"AFS Error Interpretation (TCP)", 7006, NET_FW_IP_PROTOCOL_TCP },
71         { L"AFS Basic Overseer (UDP)", 7007, NET_FW_IP_PROTOCOL_UDP },
72         { L"AFS Basic Overseer (TCP)", 7007, NET_FW_IP_PROTOCOL_TCP },
73         { L"AFS Server-to-server Updater (UDP)", 7008, NET_FW_IP_PROTOCOL_UDP },
74         { L"AFS Server-to-server Updater (TCP)", 7008, NET_FW_IP_PROTOCOL_TCP },
75         { L"AFS Remote Cache Manager (UDP)", 7009, NET_FW_IP_PROTOCOL_UDP },
76         { L"AFS Remote Cache Manager (TCP)", 7009, NET_FW_IP_PROTOCOL_TCP }
77 };
78
79 HRESULT icf_OpenFirewallProfile(INetFwProfile ** fwProfile) {
80     HRESULT hr = S_OK;
81     INetFwMgr* fwMgr = NULL;
82     INetFwPolicy* fwPolicy = NULL;
83
84     *fwProfile = NULL;
85
86     // Create an instance of the firewall settings manager.
87     hr = CoCreateInstance(
88             __uuidof(NetFwMgr),
89             NULL,
90             CLSCTX_INPROC_SERVER,
91             __uuidof(INetFwMgr),
92             reinterpret_cast<void**>(static_cast<INetFwMgr**>(&fwMgr))
93             );
94     if (FAILED(hr))
95     {
96                 DEBUGOUT(("Can't create fwMgr\n"));
97         goto error;
98     }
99
100     // Retrieve the local firewall policy.
101     hr = fwMgr->get_LocalPolicy(&fwPolicy);
102     if (FAILED(hr))
103     {
104                 DEBUGOUT(("Cant get local policy\n"));
105         goto error;
106     }
107
108     // Retrieve the firewall profile currently in effect.
109     hr = fwPolicy->get_CurrentProfile(fwProfile);
110     if (FAILED(hr))
111     {
112                 DEBUGOUT(("Can't get current profile\n"));
113         goto error;
114     }
115
116 error:
117
118     // Release the local firewall policy.
119     if (fwPolicy != NULL)
120     {
121         fwPolicy->Release();
122     }
123
124     // Release the firewall settings manager.
125     if (fwMgr != NULL)
126     {
127         fwMgr->Release();
128     }
129
130     return hr;
131 }
132
133 HRESULT icf_CheckAndAddPorts(INetFwProfile * fwProfile, global_afs_port_t * ports, int nPorts) {
134         INetFwOpenPorts * fwPorts = NULL;
135         INetFwOpenPort * fwPort = NULL;
136         HRESULT hr;
137         HRESULT rhr = S_OK; /* return value */
138         int i = 0;
139
140         hr = fwProfile->get_GloballyOpenPorts(&fwPorts);
141         if (FAILED(hr)) {
142                 // Abort!
143                 DEBUGOUT(("Can't get globallyOpenPorts\n"));
144                 rhr = hr;
145                 goto cleanup;
146         }
147
148         // go through the supplied ports
149         for (i=0; i<nPorts; i++) {
150                 VARIANT_BOOL vbEnabled;
151                 BSTR bstName = NULL;
152                 BOOL bCreate = FALSE;
153                 fwPort = NULL;
154
155                 hr = fwPorts->Item(ports[i].port, ports[i].protocol, &fwPort);
156                 if (SUCCEEDED(hr)) {
157                         DEBUGOUT(("Found port for %S\n",ports[i].name));
158             hr = fwPort->get_Enabled(&vbEnabled);
159                         if (SUCCEEDED(hr)) {
160                                 if ( vbEnabled == VARIANT_FALSE ) {
161                                         hr = fwPort->put_Enabled(VARIANT_TRUE);
162                                         if (FAILED(hr)) {
163                                                 // failed. Mark as failure. Don't try to create the port either.
164                                                 rhr = hr;
165                                         }
166                                 } // else we are fine
167                         } else {
168                 // Something is wrong with the port.
169                                 // We try to create a new one thus overriding this faulty one.
170                                 bCreate = TRUE;
171                         }
172                         fwPort->Release();
173                         fwPort = NULL;
174                 } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)) {
175                         DEBUGOUT(("Port not found for %S\n", ports[i].name));
176                         bCreate = TRUE;
177                 }
178
179                 if (bCreate) {
180                         DEBUGOUT(("Trying to create port %S\n",ports[i].name));
181                         hr = CoCreateInstance(
182                                 __uuidof(NetFwOpenPort),
183                                 NULL,
184                                 CLSCTX_INPROC_SERVER,
185                                 __uuidof(INetFwOpenPort),
186                                 reinterpret_cast<void**>
187                                         (static_cast<INetFwOpenPort**>(&fwPort))
188                                 );
189
190                         if (FAILED(hr)) {
191                                 DEBUGOUT(("Can't create port\n"));
192                 rhr = hr;
193                         } else {
194                                 DEBUGOUT(("Created port\n"));
195                                 hr = fwPort->put_IpVersion( NET_FW_IP_VERSION_ANY );
196                                 if (FAILED(hr)) {
197                                         DEBUGOUT(("Can't set IpVersion\n"));
198                                         rhr = hr;
199                                         goto abandon_port;
200                                 }
201
202                                 hr = fwPort->put_Port( ports[i].port );
203                                 if (FAILED(hr)) {
204                                         DEBUGOUT(("Can't set Port\n"));
205                                         rhr = hr;
206                                         goto abandon_port;
207                                 }
208
209                                 hr = fwPort->put_Protocol( ports[i].protocol );
210                                 if (FAILED(hr)) {
211                                         DEBUGOUT(("Can't set Protocol\n"));
212                                         rhr = hr;
213                                         goto abandon_port;
214                                 }
215
216                                 hr = fwPort->put_Scope( NET_FW_SCOPE_ALL );
217                                 if (FAILED(hr)) {
218                                         DEBUGOUT(("Can't set Scope\n"));
219                                         rhr = hr;
220                                         goto abandon_port;
221                                 }
222
223                                 bstName = SysAllocString( ports[i].name );
224
225                                 if (SysStringLen(bstName) == 0) {
226                                         rhr = E_OUTOFMEMORY;
227                                 } else {
228                                         hr = fwPort->put_Name( bstName );
229                                         if (FAILED(hr)) {
230                                                 DEBUGOUT(("Can't set Name\n"));
231                                                 rhr = hr;
232                                                 SysFreeString( bstName );
233                                                 goto abandon_port;
234                                         }
235                                 }
236
237                                 SysFreeString( bstName );
238
239                                 hr = fwPorts->Add( fwPort );
240                                 if (FAILED(hr)) {
241                                         DEBUGOUT(("Can't add port\n"));
242                                         rhr = hr;
243                                 } else
244                                         DEBUGOUT(("Added port\n"));
245
246 abandon_port:
247                                 fwPort->Release();
248                         }
249                 }
250         } // loop through ports
251
252         fwPorts->Release();
253
254 cleanup:
255
256         if (fwPorts != NULL)
257                 fwPorts->Release();
258
259         return rhr;
260 }
261
262 long icf_CheckAndAddAFSPorts(int portset) {
263         HRESULT hr;
264         BOOL coInitialized = FALSE;
265         INetFwProfile * fwProfile = NULL;
266         global_afs_port_t * ports;
267         int nports;
268         long code = 0;
269
270         if (portset == AFS_PORTSET_CLIENT) {
271                 ports = afs_clientPorts;
272                 nports = sizeof(afs_clientPorts) / sizeof(*afs_clientPorts);
273         } else if (portset == AFS_PORTSET_SERVER) {
274                 ports = afs_serverPorts;
275                 nports = sizeof(afs_serverPorts) / sizeof(*afs_serverPorts);
276         } else
277                 return 1; /* Invalid port set */
278
279         hr = CoInitializeEx(
280         NULL,
281         COINIT_APARTMENTTHREADED | COINIT_DISABLE_OLE1DDE
282         );
283
284         if (SUCCEEDED(hr) || RPC_E_CHANGED_MODE == hr)
285     {
286        coInitialized = TRUE;
287     }
288         // not necessarily catastrophic if the call failed.  We'll try to
289         // continue as if it succeeded.
290
291     hr = icf_OpenFirewallProfile(&fwProfile);
292         if (FAILED(hr)) {
293                 // Ok. That didn't work.  This could be because the machine we
294                 // are running on doesn't have Windows Firewall.  We'll return
295                 // a failure to the caller, which shouldn't be taken to mean
296                 // it's catastrophic.
297                 DEBUGOUT(("Can't open Firewall profile\n"));
298                 code = 1;
299                 goto cleanup;
300         }
301
302         // Now that we have a firewall profile, we can start checking
303         // and adding the ports that we want.
304         hr = icf_CheckAndAddPorts(fwProfile, ports, nports);
305         if (FAILED(hr))
306                 code = 1;
307
308 cleanup:
309         if (coInitialized) {
310                 CoUninitialize();
311         }
312
313         return code;
314 }
315
316
317 #ifdef TESTMAIN
318 int main(int argc, char **argv) {
319         printf("Starting...\n");
320     if (icf_CheckAndAddAFSPorts(AFS_PORTSET_CLIENT))
321                 printf("Failed\n");
322         else
323                 printf("Succeeded\n");
324         printf("Done\n");
325         return 0;
326 }
327 #endif