]> arthur.barton.de Git - netatalk.git/blobdiff - libatalk/dsi/dsi_stream.c
Fix data corruption bug
[netatalk.git] / libatalk / dsi / dsi_stream.c
index 394b707a044dd58827bfefd1e13274c77d6bf565..711a037b5098925c59c77198bb6350d390d9b5a3 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
+ * Copyright (c) 2010,2011,2012 Frank Lahm <franklahm@googlemail.com>
  * All rights reserved. See COPYRIGHT.
  *
  * this file provides the following functions:
 
 #include <stdio.h>
 #include <stdlib.h>
-
-#ifdef HAVE_UNISTD_H
 #include <unistd.h>
-#endif
-
 #include <string.h>
 #include <errno.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/uio.h>
 
+#ifdef HAVE_SENDFILEV
+#include <sys/sendfile.h>
+#endif
+
 #include <atalk/logger.h>
 #include <atalk/dsi.h>
-#include <netatalk/endian.h>
 #include <atalk/util.h>
 
-#define min(a,b)  ((a) < (b) ? (a) : (b))
-
 #ifndef MSG_MORE
 #define MSG_MORE 0x8000
 #endif
 #define MSG_DONTWAIT 0x40
 #endif
 
+/* Pack a DSI header in wire format */
+static void dsi_header_pack_reply(const DSI *dsi, char *buf)
+{
+    buf[0] = dsi->header.dsi_flags;
+    buf[1] = dsi->header.dsi_command;
+    memcpy(buf + 2, &dsi->header.dsi_requestID, sizeof(dsi->header.dsi_requestID));           
+    memcpy(buf + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
+    memcpy(buf + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
+    memcpy(buf + 12, &dsi->header.dsi_reserved, sizeof(dsi->header.dsi_reserved));
+}
+
 /*
  * afpd is sleeping too much while trying to send something.
  * May be there's no reader or the reader is also sleeping in write,
@@ -62,6 +71,9 @@ static int dsi_peek(DSI *dsi)
     maxfd = dsi->socket + 1;
 
     while (1) {
+        if (dsi->socket == -1)
+            /* eg dsi_disconnect() might have disconnected us */
+            return -1;
         FD_ZERO(&readfds);
         FD_ZERO(&writefds);
 
@@ -125,7 +137,7 @@ static int dsi_peek(DSI *dsi)
 /* 
  * Return all bytes up to count from dsi->buffer if there are any buffered there
  */
-static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
+static size_t from_buf(DSI *dsi, uint8_t *buf, size_t count)
 {
     size_t nbe = 0;
 
@@ -138,7 +150,7 @@ static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
     nbe = dsi->eof - dsi->start;
 
     if (nbe > 0) {
-        nbe = min((size_t)nbe, count);
+        nbe = MIN((size_t)nbe, count);
         memcpy(buf, dsi->start, nbe);
         dsi->start += nbe;
 
@@ -160,7 +172,7 @@ static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
  *    Note: this may return fewer bytes then requested in count !!
  * 3. If the buffer was empty, read from the socket.
  */
-static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
+static ssize_t buf_read(DSI *dsi, uint8_t *buf, size_t count)
 {
     ssize_t len;
 
@@ -184,7 +196,7 @@ static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
  * Get "length" bytes from buffer and/or socket. In order to avoid frequent small reads
  * this tries to read larger chunks (8192 bytes) into a buffer.
  */
-static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
+static size_t dsi_buffered_stream_read(DSI *dsi, uint8_t *data, const size_t length)
 {
   size_t len;
   size_t buflen;
@@ -198,7 +210,7 @@ static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t le
   }
 
   /* fill the buffer with 8192 bytes or until buffer is full */
-  buflen = min(8192, dsi->end - dsi->eof);
+  buflen = MIN(8192, dsi->end - dsi->eof);
   if (buflen > 0) {
       ssize_t ret;
       ret = read(dsi->socket, dsi->eof, buflen);
@@ -236,7 +248,7 @@ static void unblock_sig(DSI *dsi)
  * Communication error with the client, enter disconnected state
  *
  * 1. close the socket
- * 2. set the DSI_DISCONNECTED flag
+ * 2. set the DSI_DISCONNECTED flag, remove possible sleep flags
  *
  * @returns  0 if successfully entered disconnected state
  *          -1 if ppid is 1 which means afpd master died
@@ -244,9 +256,11 @@ static void unblock_sig(DSI *dsi)
  */
 int dsi_disconnect(DSI *dsi)
 {
+    LOG(log_note, logtype_dsi, "dsi_disconnect: entering disconnected state");
     dsi->proto_close(dsi);          /* 1 */
-    dsi->flags |= DSI_DISCONNECTED; /* 2 */
-    if (getppid() == 1 || geteuid() == 0)
+    dsi->flags &= ~(DSI_SLEEPING | DSI_EXTSLEEP); /* 2 */
+    dsi->flags |= DSI_DISCONNECTED;
+    if (geteuid() == 0)
         return -1;
     return 0;
 }
@@ -259,18 +273,23 @@ ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
 {
   size_t written;
   ssize_t len;
-  unsigned int flags = 0;
+  unsigned int flags;
 
   dsi->in_write++;
   written = 0;
 
-  LOG(log_maxdebug, logtype_dsi, "dsi_stream_write: sending %u bytes", length);
+  LOG(log_maxdebug, logtype_dsi, "dsi_stream_write(send: %zd bytes): START", length);
 
   if (dsi->flags & DSI_DISCONNECTED)
       return -1;
 
+  if (mode & DSI_MSG_MORE)
+      flags = MSG_MORE;
+  else
+      flags = 0;
+
   while (written < length) {
-      len = send(dsi->socket, (u_int8_t *) data + written, length - written, flags);
+      len = send(dsi->socket, (uint8_t *) data + written, length - written, flags);
       if (len >= 0) {
           written += len;
           continue;
@@ -303,62 +322,125 @@ ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
   }
 
   dsi->write_count += written;
+  LOG(log_maxdebug, logtype_dsi, "dsi_stream_write(send: %zd bytes): END", length);
 
 exit:
   dsi->in_write--;
   return written;
 }
 
-
 /* ---------------------------------
 */
 #ifdef WITH_SENDFILE
-ssize_t dsi_stream_read_file(DSI *dsi, int fromfd, off_t offset, const size_t length)
+ssize_t dsi_stream_read_file(DSI *dsi, const int fromfd, off_t offset, const size_t length, const int err)
 {
-  size_t written;
-  ssize_t len;
+    int ret = 0;
+    size_t written = 0;
+    size_t total = length;
+    ssize_t len;
+    off_t pos = offset;
+    char block[DSI_BLOCKSIZ];
+#ifdef HAVE_SENDFILEV
+    int sfvcnt;
+    struct sendfilevec vec[2];
+    ssize_t nwritten;
+#endif
 
-  LOG(log_maxdebug, logtype_dsi, "dsi_stream_read_file: sending %u bytes", length);
+    LOG(log_maxdebug, logtype_dsi, "dsi_stream_read_file(off: %jd, len: %zu)", (intmax_t)offset, length);
 
-  if (dsi->flags & DSI_DISCONNECTED)
-      return -1;
+    if (dsi->flags & DSI_DISCONNECTED)
+        return -1;
 
-  dsi->in_write++;
-  written = 0;
+    dsi->in_write++;
+
+    dsi->flags |= DSI_NOREPLY;
+    dsi->header.dsi_flags = DSIFL_REPLY;
+    dsi->header.dsi_len = htonl(length);
+    dsi->header.dsi_code = htonl(err);
+    dsi_header_pack_reply(dsi, block);
+
+#ifdef HAVE_SENDFILEV
+    total += DSI_BLOCKSIZ;
+    sfvcnt = 2;
+    vec[0].sfv_fd = SFV_FD_SELF;
+    vec[0].sfv_flag = 0;
+    /* Cast to unsigned long to prevent sign extension of the
+     * pointer value for the LFS case; see Apache PR 39463. */
+    vec[0].sfv_off = (unsigned long)block;
+    vec[0].sfv_len = DSI_BLOCKSIZ;
+    vec[1].sfv_fd = fromfd;
+    vec[1].sfv_flag = 0;
+    vec[1].sfv_off = offset;
+    vec[1].sfv_len = length;
+#else
+    dsi_stream_write(dsi, block, sizeof(block), DSI_MSG_MORE);
+#endif
 
-  while (written < length) {
-    len = sys_sendfile(dsi->socket, fromfd, &offset, length - written);
-        
-    if (len < 0) {
-      if (errno == EINTR)
-          continue;
-      if (errno == EINVAL || errno == ENOSYS)
-          return -1;
-          
-      if (errno == EAGAIN || errno == EWOULDBLOCK) {
-          if (dsi_peek(dsi)) {
-              /* can't go back to blocking mode, exit, the next read
-                 will return with an error and afpd will die.
-              */
-              break;
-          }
-          continue;
-      }
-      LOG(log_error, logtype_dsi, "dsi_stream_read_file: %s", strerror(errno));
-      break;
-    }
-    else if (!len) {
-        /* afpd is going to exit */
-        errno = EIO;
-        return -1; /* I think we're at EOF here... */
-    }
-    else 
+    while (written < total) {
+#ifdef HAVE_SENDFILEV
+        nwritten = 0;
+        len = sendfilev(dsi->socket, vec, sfvcnt, &nwritten);
+#else
+        len = sys_sendfile(dsi->socket, fromfd, &pos, total - written);
+#endif
+        if (len < 0) {
+            switch (errno) {
+            case EINTR:
+            case EAGAIN:
+                len = 0;
+#ifdef HAVE_SENDFILEV
+                len = (size_t)nwritten;
+#else
+#if defined(SOLARIS) || defined(FREEBSD)
+                if (pos > offset) {
+                    /* we actually have sent sth., adjust counters and keep trying */
+                    len = pos - offset;
+                    offset = pos;
+                }
+#endif /* defined(SOLARIS) || defined(FREEBSD) */
+#endif /* HAVE_SENDFILEV */
+
+                if (dsi_peek(dsi) != 0) {
+                    ret = -1;
+                    goto exit;
+                }
+                break;
+            default:
+                LOG(log_error, logtype_dsi, "dsi_stream_read_file: %s", strerror(errno));
+                ret = -1;
+                goto exit;
+            }
+        } else if (len == 0) {
+            /* afpd is going to exit */
+            ret = -1;
+            goto exit;
+        }
+#ifdef HAVE_SENDFILEV
+        if (sfvcnt == 2 && len >= vec[0].sfv_len) {
+            vec[1].sfv_off += len - vec[0].sfv_len;
+            vec[1].sfv_len -= len - vec[0].sfv_len;
+
+            vec[0] = vec[1];
+            sfvcnt = 1;
+        } else {
+            vec[0].sfv_off += len;
+            vec[0].sfv_len -= len;
+        }
+#endif  /* HAVE_SENDFILEV */
+        LOG(log_maxdebug, logtype_dsi, "dsi_stream_read_file: wrote: %zd", len);
         written += len;
-  }
+    }
+#ifdef HAVE_SENDFILEV
+    written -= DSI_BLOCKSIZ;
+#endif
+    dsi->write_count += written;
 
-  dsi->write_count += written;
-  dsi->in_write--;
-  return written;
+exit:
+    dsi->in_write--;
+    LOG(log_maxdebug, logtype_dsi, "dsi_stream_read_file: written: %zd", written);
+    if (ret != 0)
+        return -1;
+    return written;
 }
 #endif
 
@@ -381,14 +463,18 @@ size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
 
   stored = 0;
   while (stored < length) {
-      len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
+      len = buf_read(dsi, (uint8_t *) data + stored, length - stored);
       if (len == -1 && (errno == EINTR || errno == EAGAIN)) {
-          LOG(log_debug, logtype_dsi, "dsi_stream_read: select read loop");
+          LOG(log_maxdebug, logtype_dsi, "dsi_stream_read: select read loop");
           continue;
       } else if (len > 0) {
           stored += len;
       } else { /* eof or error */
           /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
+#if 0
+          if (errno == ECONNRESET)
+              dsi->flags |= DSI_GOT_ECONNRESET;
+#endif
           if (len || stored || dsi->read_count) {
               if (! (dsi->flags & DSI_DISCONNECTED)) {
                   LOG(log_error, logtype_dsi, "dsi_stream_read: len:%d, %s",
@@ -417,22 +503,15 @@ int dsi_stream_send(DSI *dsi, void *buf, size_t length)
   size_t towrite;
   ssize_t len;
 
-  LOG(log_maxdebug, logtype_dsi, "dsi_stream_send: %u bytes",
-      length ? length : sizeof(block));
+  LOG(log_maxdebug, logtype_dsi, "dsi_stream_send(%u bytes): START", length);
 
   if (dsi->flags & DSI_DISCONNECTED)
       return 0;
 
-  block[0] = dsi->header.dsi_flags;
-  block[1] = dsi->header.dsi_command;
-  memcpy(block + 2, &dsi->header.dsi_requestID, 
-        sizeof(dsi->header.dsi_requestID));
-  memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
-  memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
-  memcpy(block + 12, &dsi->header.dsi_reserved,
-        sizeof(dsi->header.dsi_reserved));
+  dsi_header_pack_reply(dsi, block);
 
   if (!length) { /* just write the header */
+      LOG(log_maxdebug, logtype_dsi, "dsi_stream_send(%u bytes): DSI header, no data", sizeof(block));
     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
     return length; /* really 0 on failure, 1 on success */
   }
@@ -454,7 +533,7 @@ int dsi_stream_send(DSI *dsi, void *buf, size_t length)
           break;
       else if (len < 0) { /* error */
           if (errno == EAGAIN || errno == EWOULDBLOCK) {
-              if (!dsi_peek(dsi)) {
+              if (dsi_peek(dsi) == 0) {
                   continue;
               }
           }
@@ -476,52 +555,52 @@ int dsi_stream_send(DSI *dsi, void *buf, size_t length)
           iov[1].iov_len -= len;
       }
   }
+
+  LOG(log_maxdebug, logtype_dsi, "dsi_stream_send(%u bytes): END", length);
   
   unblock_sig(dsi);
   return 1;
 }
 
 
-/* ---------------------------------------
- * read data. function on success. 0 on failure. data length gets
- * stored in length variable. this should really use size_t's, but
- * that would require changes elsewhere. */
-int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
-                      size_t *rlength)
+/*!
+ * Read DSI command and data
+ *
+ * @param  dsi   (rw) DSI handle
+ *
+ * @return    DSI function on success, 0 on failure
+ */
+int dsi_stream_receive(DSI *dsi)
 {
   char block[DSI_BLOCKSIZ];
 
-  LOG(log_maxdebug, logtype_dsi, "dsi_stream_receive: %u bytes", ilength);
+  LOG(log_maxdebug, logtype_dsi, "dsi_stream_receive: START");
 
   if (dsi->flags & DSI_DISCONNECTED)
       return 0;
 
   /* read in the header */
-  if (dsi_buffered_stream_read(dsi, (u_int8_t *)block, sizeof(block)) != sizeof(block)) 
+  if (dsi_buffered_stream_read(dsi, (uint8_t *)block, sizeof(block)) != sizeof(block)) 
     return 0;
 
   dsi->header.dsi_flags = block[0];
   dsi->header.dsi_command = block[1];
-  /* FIXME, not the right place, 
-     but we get a server disconnect without reason in the log
-  */
-  if (!block[1]) {
-      LOG(log_error, logtype_dsi, "dsi_stream_receive: invalid packet, fatal");
+
+  if (dsi->header.dsi_command == 0)
       return 0;
-  }
 
-  memcpy(&dsi->header.dsi_requestID, block + 2, 
-        sizeof(dsi->header.dsi_requestID));
+  memcpy(&dsi->header.dsi_requestID, block + 2, sizeof(dsi->header.dsi_requestID));
   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
-  memcpy(&dsi->header.dsi_reserved, block + 12,
-        sizeof(dsi->header.dsi_reserved));
+  memcpy(&dsi->header.dsi_reserved, block + 12, sizeof(dsi->header.dsi_reserved));
   dsi->clientID = ntohs(dsi->header.dsi_requestID);
   
   /* make sure we don't over-write our buffers. */
-  *rlength = min(ntohl(dsi->header.dsi_len), ilength);
-  if (dsi_stream_read(dsi, buf, *rlength) != *rlength
+  dsi->cmdlen = MIN(ntohl(dsi->header.dsi_len), dsi->server_quantum);
+  if (dsi_stream_read(dsi, dsi->commands, dsi->cmdlen) != dsi->cmdlen
     return 0;
 
+  LOG(log_debug, logtype_dsi, "dsi_stream_receive: DSI cmdlen: %zd", dsi->cmdlen);
+
   return block[1];
 }