Windows: prevent cm_server races
[openafs.git] / src / WINNT / afsd / cm_server.c
index b7a4081..a294824 100644 (file)
@@ -131,13 +131,13 @@ cm_PingServer(cm_server_t *tsp)
        lock_ObtainMutex(&tsp->mx);
        tsp->waitCount--;
        if (tsp->waitCount == 0)
-           tsp->flags &= ~CM_SERVERFLAG_PINGING;
+           _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_PINGING);
        else
            osi_Wakeup((LONG_PTR)tsp);
        lock_ReleaseMutex(&tsp->mx);
        return;
     }
-    tsp->flags |= CM_SERVERFLAG_PINGING;
+    _InterlockedOr(&tsp->flags, CM_SERVERFLAG_PINGING);
     wasDown = tsp->flags & CM_SERVERFLAG_DOWN;
     afs_inet_ntoa_r(tsp->addr.sin_addr.S_un.S_addr, hoststr);
     lock_ReleaseMutex(&tsp->mx);
@@ -174,7 +174,7 @@ cm_PingServer(cm_server_t *tsp)
     lock_ObtainMutex(&tsp->mx);
     if (code >= 0 || code == RXGEN_OPCODE) {
        /* mark server as up */
-       tsp->flags &= ~CM_SERVERFLAG_DOWN;
+       _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_DOWN);
         tsp->downTime = 0;
 
        /* we currently handle 32-bits of capabilities */
@@ -218,7 +218,7 @@ cm_PingServer(cm_server_t *tsp)
     } else {
        /* mark server as down */
         if (!(tsp->flags & CM_SERVERFLAG_DOWN)) {
-            tsp->flags |= CM_SERVERFLAG_DOWN;
+            _InterlockedOr(&tsp->flags, CM_SERVERFLAG_DOWN);
             tsp->downTime = time(NULL);
         }
        if (code != VRESTARTING) {
@@ -257,7 +257,7 @@ cm_PingServer(cm_server_t *tsp)
     }
 
     if (tsp->waitCount == 0)
-       tsp->flags &= ~CM_SERVERFLAG_PINGING;
+       _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_PINGING);
     else
        osi_Wakeup((LONG_PTR)tsp);
     lock_ReleaseMutex(&tsp->mx);
@@ -413,7 +413,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
                 continue;
             }
 
-            tsp->flags |= CM_SERVERFLAG_PINGING;
+            _InterlockedOr(&tsp->flags, CM_SERVERFLAG_PINGING);
             lock_ReleaseMutex(&tsp->mx);
 
             serversp[nconns] = tsp;
@@ -457,7 +457,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
 
             if (results[i] >= 0 || results[i] == RXGEN_OPCODE)  {
                 /* mark server as up */
-                tsp->flags &= ~CM_SERVERFLAG_DOWN;
+                _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_DOWN);
                 tsp->downTime = 0;
 
                 /* we currently handle 32-bits of capabilities */
@@ -502,7 +502,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
             } else {
                 /* mark server as down */
                 if (!(tsp->flags & CM_SERVERFLAG_DOWN)) {
-                    tsp->flags |= CM_SERVERFLAG_DOWN;
+                    _InterlockedOr(&tsp->flags, CM_SERVERFLAG_DOWN);
                     tsp->downTime = time(NULL);
                 }
                 if (code != VRESTARTING) {
@@ -542,7 +542,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
             }
 
             if (tsp->waitCount == 0)
-                tsp->flags &= ~CM_SERVERFLAG_PINGING;
+                _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_PINGING);
             else
                 osi_Wakeup((LONG_PTR)tsp);
 
@@ -577,7 +577,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
                 continue;
             }
 
-            tsp->flags |= CM_SERVERFLAG_PINGING;
+            _InterlockedOr(&tsp->flags, CM_SERVERFLAG_PINGING);
             lock_ReleaseMutex(&tsp->mx);
 
             serversp[nconns] = tsp;
@@ -622,7 +622,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
 
             if (results[i] >= 0)  {
                 /* mark server as up */
-                tsp->flags &= ~CM_SERVERFLAG_DOWN;
+                _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_DOWN);
                 tsp->downTime = 0;
                 tsp->capabilities = 0;
 
@@ -634,7 +634,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
             } else {
                 /* mark server as down */
                 if (!(tsp->flags & CM_SERVERFLAG_DOWN)) {
-                    tsp->flags |= CM_SERVERFLAG_DOWN;
+                    _InterlockedOr(&tsp->flags, CM_SERVERFLAG_DOWN);
                     tsp->downTime = time(NULL);
                 }
                 if (code != VRESTARTING) {
@@ -650,7 +650,7 @@ static void cm_CheckServersMulti(afs_uint32 flags, cm_cell_t *cellp)
             }
 
             if (tsp->waitCount == 0)
-                tsp->flags &= ~CM_SERVERFLAG_PINGING;
+                _InterlockedAnd(&tsp->flags, ~CM_SERVERFLAG_PINGING);
             else
                 osi_Wakeup((LONG_PTR)tsp);
 
@@ -748,9 +748,9 @@ void cm_SetServerNo64Bit(cm_server_t * serverp, int no64bit)
 {
     lock_ObtainMutex(&serverp->mx);
     if (no64bit)
-        serverp->flags |= CM_SERVERFLAG_NO64BIT;
+        _InterlockedOr(&serverp->flags, CM_SERVERFLAG_NO64BIT);
     else
-        serverp->flags &= ~CM_SERVERFLAG_NO64BIT;
+        _InterlockedAnd(&serverp->flags, ~CM_SERVERFLAG_NO64BIT);
     lock_ReleaseMutex(&serverp->mx);
 }
 
@@ -758,9 +758,9 @@ void cm_SetServerNoInlineBulk(cm_server_t * serverp, int no)
 {
     lock_ObtainMutex(&serverp->mx);
     if (no)
-        serverp->flags |= CM_SERVERFLAG_NOINLINEBULK;
+        _InterlockedOr(&serverp->flags, CM_SERVERFLAG_NOINLINEBULK);
     else
-        serverp->flags &= ~CM_SERVERFLAG_NOINLINEBULK;
+        _InterlockedAnd(&serverp->flags, ~CM_SERVERFLAG_NOINLINEBULK);
     lock_ReleaseMutex(&serverp->mx);
 }
 
@@ -834,6 +834,20 @@ cm_server_t *cm_NewServer(struct sockaddr_in *socketp, int type, cm_cell_t *cell
 
     osi_assertx(socketp->sin_family == AF_INET, "unexpected socket family");
 
+    lock_ObtainWrite(&cm_serverLock);  /* get server lock */
+    tsp = cm_FindServer(socketp, type, TRUE);
+    if (tsp) {
+        /* we might have found a server created by set server prefs */
+        if (uuidp && !afs_uuid_is_nil(uuidp) &&
+            !(tsp->flags & CM_SERVERFLAG_UUID))
+        {
+            tsp->uuid = *uuidp;
+            _InterlockedOr(&tsp->flags, CM_SERVERFLAG_UUID);
+        }
+        lock_ReleaseWrite(&cm_serverLock);
+        return tsp;
+    }
+
     tsp = malloc(sizeof(*tsp));
     if (tsp) {
         memset(tsp, 0, sizeof(*tsp));
@@ -841,7 +855,7 @@ cm_server_t *cm_NewServer(struct sockaddr_in *socketp, int type, cm_cell_t *cell
         tsp->cellp = cellp;
         if (uuidp && !afs_uuid_is_nil(uuidp)) {
             tsp->uuid = *uuidp;
-            tsp->flags |= CM_SERVERFLAG_UUID;
+            _InterlockedOr(&tsp->flags, CM_SERVERFLAG_UUID);
         }
         tsp->refCount = 1;
         lock_InitializeMutex(&tsp->mx, "cm_server_t mutex", LOCK_HIERARCHY_SERVER);
@@ -849,7 +863,6 @@ cm_server_t *cm_NewServer(struct sockaddr_in *socketp, int type, cm_cell_t *cell
 
         cm_SetServerPrefs(tsp);
 
-        lock_ObtainWrite(&cm_serverLock);      /* get server lock */
         tsp->allNextp = cm_allServersp;
         cm_allServersp = tsp;
 
@@ -861,23 +874,25 @@ cm_server_t *cm_NewServer(struct sockaddr_in *socketp, int type, cm_cell_t *cell
             cm_numFileServers++;
             break;
         }
+    }
+    lock_ReleaseWrite(&cm_serverLock);         /* release server lock */
 
-        lock_ReleaseWrite(&cm_serverLock);     /* release server lock */
-
-        if ( !(flags & CM_FLAG_NOPROBE) ) {
-            tsp->flags |= CM_SERVERFLAG_DOWN;  /* assume down; ping will mark up if available */
-            cm_PingServer(tsp);                        /* Obtain Capabilities and check up/down state */
-        }
+    if (!(flags & CM_FLAG_NOPROBE) && tsp) {
+        _InterlockedOr(&tsp->flags, CM_SERVERFLAG_DOWN);       /* assume down; ping will mark up if available */
+        cm_PingServer(tsp);                                    /* Obtain Capabilities and check up/down state */
     }
+
     return tsp;
 }
 
 cm_server_t *
-cm_FindServerByIP(afs_uint32 ipaddr, unsigned short port, int type)
+cm_FindServerByIP(afs_uint32 ipaddr, unsigned short port, int type, int locked)
 {
     cm_server_t *tsp;
 
-    lock_ObtainRead(&cm_serverLock);
+    if (!locked)
+        lock_ObtainRead(&cm_serverLock);
+
     for (tsp = cm_allServersp; tsp; tsp = tsp->allNextp) {
         if (tsp->type == type &&
             tsp->addr.sin_addr.S_un.S_addr == ipaddr &&
@@ -889,17 +904,20 @@ cm_FindServerByIP(afs_uint32 ipaddr, unsigned short port, int type)
     if (tsp)
         cm_GetServerNoLock(tsp);
 
-    lock_ReleaseRead(&cm_serverLock);
+    if (!locked)
+        lock_ReleaseRead(&cm_serverLock);
 
     return tsp;
 }
 
 cm_server_t *
-cm_FindServerByUuid(afsUUID *serverUuid, int type)
+cm_FindServerByUuid(afsUUID *serverUuid, int type, int locked)
 {
     cm_server_t *tsp;
 
-    lock_ObtainRead(&cm_serverLock);
+    if (locked)
+        lock_ObtainRead(&cm_serverLock);
+
     for (tsp = cm_allServersp; tsp; tsp = tsp->allNextp) {
         if (tsp->type == type && !afs_uuid_equal(&tsp->uuid, serverUuid))
             break;
@@ -909,35 +927,18 @@ cm_FindServerByUuid(afsUUID *serverUuid, int type)
     if (tsp)
         cm_GetServerNoLock(tsp);
 
-    lock_ReleaseRead(&cm_serverLock);
+    if (!locked)
+        lock_ReleaseRead(&cm_serverLock);
 
     return tsp;
 }
 
 /* find a server based on its properties */
-cm_server_t *cm_FindServer(struct sockaddr_in *addrp, int type)
+cm_server_t *cm_FindServer(struct sockaddr_in *addrp, int type, int locked)
 {
-    cm_server_t *tsp;
-
     osi_assertx(addrp->sin_family == AF_INET, "unexpected socket value");
 
-    lock_ObtainRead(&cm_serverLock);
-    for (tsp = cm_allServersp; tsp; tsp=tsp->allNextp) {
-        if (tsp->type == type &&
-            tsp->addr.sin_addr.s_addr == addrp->sin_addr.s_addr &&
-            (tsp->addr.sin_port == addrp->sin_port || tsp->addr.sin_port == 0))
-            break;
-    }
-
-    /* bump ref count if we found the server */
-    if (tsp)
-        cm_GetServerNoLock(tsp);
-
-    /* drop big table lock */
-    lock_ReleaseRead(&cm_serverLock);
-
-    /* return what we found */
-    return tsp;
+    return cm_FindServerByIP(addrp->sin_addr.s_addr, addrp->sin_port, type, locked);
 }
 
 cm_server_vols_t *cm_NewServerVols(void) {