Windows: Prevent pioctl races from crashing afsd_service
[openafs.git] / src / WINNT / afsd / smb_ioctl.c
index c2c339a..6492f2c 100644 (file)
@@ -129,6 +129,7 @@ smb_IoctlPrepareRead(struct smb_fid *fidp, smb_ioctl_t *ioctlp, cm_user_t *userp
 
     if (ioctlp->ioctl.flags & CM_IOCTLFLAG_DATAIN) {
         ioctlp->ioctl.flags &= ~CM_IOCTLFLAG_DATAIN;
+        ioctlp->ioctl.flags |= CM_IOCTLFLAG_DATAOUT;
 
         /* do the call now, or fail if we didn't get an opcode, or
          * enough of an opcode.
@@ -138,26 +139,31 @@ smb_IoctlPrepareRead(struct smb_fid *fidp, smb_ioctl_t *ioctlp, cm_user_t *userp
         memcpy(&opcode, ioctlp->ioctl.inDatap, sizeof(afs_int32));
         ioctlp->ioctl.inDatap += sizeof(afs_int32);
 
-        osi_Log1(afsd_logp, "Ioctl opcode 0x%x", opcode);
-
+        osi_Log1(afsd_logp, "smb_IoctlPrepareRead opcode 0x%x", opcode);
         /* check for opcode out of bounds */
-        if (opcode < 0 || opcode >= SMB_IOCTL_MAXPROCS)
+        if (opcode < 0 || opcode >= SMB_IOCTL_MAXPROCS) {
+            osi_Log0(afsd_logp, "smb_IoctlPrepareRead - invalid opcode");
             return CM_ERROR_TOOBIG;
+        }
 
         /* check for no such proc */
        procp = smb_ioctlProcsp[opcode];
-        if (procp == NULL) 
-            return CM_ERROR_BADOP;
-
+        if (procp == NULL) {
+            osi_Log0(afsd_logp, "smb_IoctlPrepareRead - unassigned opcode");
+            return CM_ERROR_INVAL;
+        }
         /* otherwise, make the call */
         ioctlp->ioctl.outDatap += sizeof(afs_int32); /* reserve room for return code */
         code = (*procp)(ioctlp, userp);
-
-        osi_Log1(afsd_logp, "Ioctl return code 0x%x", code);
+        osi_Log1(afsd_logp, "smb_IoctlPrepareRead operation returns code 0x%x", code);
 
         /* copy in return code */
         memcpy(ioctlp->ioctl.outAllocp, &code, sizeof(afs_int32));
+    } else if (!(ioctlp->ioctl.flags & CM_IOCTLFLAG_DATAOUT)) {
+        osi_Log0(afsd_logp, "Ioctl invalid state - dataout expected");
+        return CM_ERROR_INVAL;
     }
+
     return 0;
 }
 
@@ -185,6 +191,7 @@ smb_IoctlPrepareWrite(smb_fid_t *fidp, smb_ioctl_t *ioctlp)
         ioctlp->ioctl.inDatap = ioctlp->ioctl.inAllocp;
         ioctlp->ioctl.outDatap = ioctlp->ioctl.outAllocp;
         ioctlp->ioctl.flags |= CM_IOCTLFLAG_DATAIN;
+        ioctlp->ioctl.flags &= ~CM_IOCTLFLAG_DATAOUT;
     }
 }       
 
@@ -219,6 +226,11 @@ smb_IoctlRead(smb_fid_t *fidp, smb_vc_t *vcp, smb_packet_t *inp, smb_packet_t *o
     }
 
     leftToCopy = (afs_int32)((iop->ioctl.outDatap - iop->ioctl.outAllocp) - iop->ioctl.outCopied);
+    if (leftToCopy < 0) {
+        osi_Log0(afsd_logp, "smb_IoctlRead leftToCopy went negative");
+        cm_ReleaseUser(userp);
+        return CM_ERROR_INVAL;
+    }
     if (count > leftToCopy)
         count = leftToCopy;
 
@@ -340,7 +352,7 @@ afs_int32
 smb_IoctlV3Read(smb_fid_t *fidp, smb_vc_t *vcp, smb_packet_t *inp, smb_packet_t *outp)
 {
     smb_ioctl_t *iop;
-    long count;
+    unsigned short count;
     afs_int32 code;
     long leftToCopy;
     char *op;
@@ -388,8 +400,13 @@ smb_IoctlV3Read(smb_fid_t *fidp, smb_vc_t *vcp, smb_packet_t *inp, smb_packet_t
     }
 
     leftToCopy = (long)((iop->ioctl.outDatap - iop->ioctl.outAllocp) - iop->ioctl.outCopied);
+    if (leftToCopy < 0) {
+        osi_Log0(afsd_logp, "smb_IoctlV3Read leftToCopy went negative");
+        cm_ReleaseUser(userp);
+        return CM_ERROR_INVAL;
+    }
     if (count > leftToCopy) 
-        count = leftToCopy;
+        count = (unsigned short)leftToCopy;
         
     /* 0 and 1 are reserved for request chaining, were setup by our caller,
      * and will be further filled in after we return.
@@ -472,6 +489,11 @@ smb_IoctlReadRaw(smb_fid_t *fidp, smb_vc_t *vcp, smb_packet_t *inp,
     }
 
     leftToCopy = (long)((iop->ioctl.outDatap - iop->ioctl.outAllocp) - iop->ioctl.outCopied);
+    if (leftToCopy < 0) {
+        osi_Log0(afsd_logp, "smb_IoctlReadRaw leftToCopy went negative");
+        code = CM_ERROR_INVAL;
+        goto done;
+    }
 
     ncbp = outp->ncbp;
     memset((char *)ncbp, 0, sizeof(NCB));