]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
Convert DSI to nonblocking
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
3  * All rights reserved. See COPYRIGHT.
4  *
5  * this file provides the following functions:
6  * dsi_stream_write:    just write a bunch of bytes.
7  * dsi_stream_read:     just read a bunch of bytes.
8  * dsi_stream_send:     send a DSI header + data.
9  * dsi_stream_receive:  read a DSI header + data.
10  */
11
12 #ifdef HAVE_CONFIG_H
13 #include "config.h"
14 #endif /* HAVE_CONFIG_H */
15
16 #include <stdio.h>
17 #include <stdlib.h>
18
19 #ifdef HAVE_UNISTD_H
20 #include <unistd.h>
21 #endif
22
23 #include <string.h>
24 #include <errno.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/uio.h>
28
29 #include <atalk/logger.h>
30 #include <atalk/dsi.h>
31 #include <netatalk/endian.h>
32 #include <atalk/util.h>
33
34 #define min(a,b)  ((a) < (b) ? (a) : (b))
35
36 #ifndef MSG_MORE
37 #define MSG_MORE 0x8000
38 #endif
39
40 #ifndef MSG_DONTWAIT
41 #define MSG_DONTWAIT 0x40
42 #endif
43
44 /* ---------------------- 
45    afpd is sleeping too much while trying to send something.
46    May be there's no reader or the reader is also sleeping in write,
47    look if there's some data for us to read, hopefully it will wake up
48    the reader so we can write again.
49 */
50 static int dsi_peek(DSI *dsi)
51 {
52     fd_set readfds, writefds;
53     int    len;
54     int    maxfd;
55     int    ret;
56
57     LOG(log_debug, logtype_dsi, "dsi_peek");
58
59     FD_ZERO(&readfds);
60     FD_ZERO(&writefds);
61     FD_SET( dsi->socket, &readfds);
62     FD_SET( dsi->socket, &writefds);
63     maxfd = dsi->socket +1;
64
65     while (1) {
66         FD_SET( dsi->socket, &readfds);
67         FD_SET( dsi->socket, &writefds);
68
69         /* No timeout: if there's nothing to read nor nothing to write,
70          * we've got nothing to do at all */
71         if ((ret = select( maxfd, &readfds, &writefds, NULL, NULL)) <= 0) {
72             if (ret == -1 && errno == EINTR)
73                 /* we might have been interrupted by out timer, so restart select */
74                 continue;
75             /* give up */
76             LOG(log_error, logtype_dsi, "dsi_peek: unexpected select return: %d %s",
77                 ret, ret < 0 ? strerror(errno) : "");
78             return -1;
79         }
80
81         /* Check if there's sth to read, hopefully reading that will unblock the client */
82         if (FD_ISSET(dsi->socket, &readfds)) {
83             len = dsi->end - dsi->eof;
84
85             if (len <= 0) {
86                 /* ouch, our buffer is full ! fall back to blocking IO 
87                  * could block and disconnect but it's better than a cpu hog */
88                 LOG(log_warning, logtype_dsi, "dsi_peek: read buffer is full");
89                 break;
90             }
91
92             if ((len = read(dsi->socket, dsi->eof, len)) <= 0) {
93                 if (len == 0) {
94                     LOG(log_error, logtype_dsi, "dsi_peek: EOF");
95                     return -1;
96                 }
97                 LOG(log_error, logtype_dsi, "dsi_peek: read: %s", strerror(errno));
98                 if (errno == EAGAIN)
99                     continue;
100                 return -1;
101             }
102             LOG(log_debug, logtype_dsi, "dsi_peek: read %d bytes", len);
103
104             dsi->eof += len;
105         }
106
107         if (FD_ISSET(dsi->socket, &writefds)) {
108             /* we can write again */
109             LOG(log_debug, logtype_dsi, "dsi_peek: can write again");
110             break;
111         }
112     }
113
114     return 0;
115 }
116
117 /* ------------------------------
118  * write raw data. return actual bytes read. checks against EINTR
119  * aren't necessary if all of the signals have SA_RESTART
120  * specified. */
121 ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
122 {
123   size_t written;
124   ssize_t len;
125   unsigned int flags = 0;
126
127   dsi->in_write++;
128   written = 0;
129
130   LOG(log_maxdebug, logtype_dsi, "dsi_stream_write: sending %u bytes", length);
131
132   while (written < length) {
133       len = send(dsi->socket, (u_int8_t *) data + written, length - written, flags);
134       if (len >= 0) {
135           written += len;
136           continue;
137       }
138
139       if (errno == EINTR)
140           continue;
141
142       if (errno == EAGAIN || errno == EWOULDBLOCK) {
143           if (mode == DSI_NOWAIT && written == 0) {
144               /* DSI_NOWAIT is used by attention give up in this case. */
145               written = -1;
146               goto exit;
147           }
148
149           /* Try to read sth. in order to break up possible deadlock */
150           if (dsi_peek(dsi) != 0) {
151               written = -1;
152               goto exit;
153           }
154           /* Now try writing again */
155           continue;
156       }
157
158       LOG(log_error, logtype_dsi, "dsi_stream_write: %s", strerror(errno));
159       written = -1;
160       goto exit;
161   }
162
163   dsi->write_count += written;
164
165 exit:
166   dsi->in_write--;
167   return written;
168 }
169
170
171 /* ---------------------------------
172 */
173 #ifdef WITH_SENDFILE
174 ssize_t dsi_stream_read_file(DSI *dsi, int fromfd, off_t offset, const size_t length)
175 {
176   size_t written;
177   ssize_t len;
178
179   LOG(log_maxdebug, logtype_dsi, "dsi_stream_read_file: sending %u bytes", length);
180
181   dsi->in_write++;
182   written = 0;
183
184   while (written < length) {
185     len = sys_sendfile(dsi->socket, fromfd, &offset, length - written);
186         
187     if (len < 0) {
188       if (errno == EINTR)
189           continue;
190       if (errno == EINVAL || errno == ENOSYS)
191           return -1;
192           
193       if (errno == EAGAIN || errno == EWOULDBLOCK) {
194           if (dsi_peek(dsi)) {
195               /* can't go back to blocking mode, exit, the next read
196                  will return with an error and afpd will die.
197               */
198               break;
199           }
200           continue;
201       }
202       LOG(log_error, logtype_dsi, "dsi_stream_read_file: %s", strerror(errno));
203       break;
204     }
205     else if (!len) {
206         /* afpd is going to exit */
207         errno = EIO;
208         return -1; /* I think we're at EOF here... */
209     }
210     else 
211         written += len;
212   }
213
214   dsi->write_count += written;
215   dsi->in_write--;
216   return written;
217 }
218 #endif
219
220 /* 
221  * Return all bytes up to count from dsi->buffer if there are any buffered there
222  */
223 static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
224 {
225     size_t nbe = 0;
226
227     LOG(log_maxdebug, logtype_dsi, "from_buf: %u bytes", count);
228     
229     if (dsi->start) {        
230         nbe = dsi->eof - dsi->start;
231
232         if (nbe > 0) {
233            nbe = min((size_t)nbe, count);
234            memcpy(buf, dsi->start, nbe);
235            dsi->start += nbe;
236
237            if (dsi->eof == dsi->start) 
238                dsi->start = dsi->eof = dsi->buffer;
239
240         }
241     }
242     return nbe;
243 }
244
245 /*
246  * Get bytes from buffer dsi->buffer or read from socket
247  *
248  * 1. Check if there are bytes in the the dsi->buffer buffer.
249  * 2. Return bytes from (1) if yes.
250  *    Note: this may return fewer bytes then requested in count !!
251  * 3. If the buffer was empty, read from the socket.
252  */
253 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
254 {
255     ssize_t nbe;
256
257     LOG(log_maxdebug, logtype_dsi, "buf_read: %u bytes", count);
258
259     if (!count)
260         return 0;
261
262     nbe = from_buf(dsi, buf, count); /* 1. */
263     if (nbe)
264         return nbe;             /* 2. */
265   
266     return readt(dsi->socket, buf, count, 0, 1); /* 3. */
267 }
268
269 /*
270  * Essentially a loop around buf_read() to ensure "length" bytes are read
271  * from dsi->buffer and/or the socket.
272  */
273 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
274 {
275   size_t stored;
276   ssize_t len;
277
278   LOG(log_maxdebug, logtype_dsi, "dsi_stream_read: %u bytes", length);
279
280   stored = 0;
281   while (stored < length) {
282     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
283     if (len == -1 && (errno == EINTR || errno == EAGAIN)) {
284       LOG(log_debug, logtype_dsi, "dsi_stream_read: select read loop");
285       continue;
286     } else if (len > 0) {
287       stored += len;
288     } else { /* eof or error */
289       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
290       if (len || stored || dsi->read_count) {
291           if (! (dsi->flags & DSI_DISCONNECTED))
292               LOG(log_error, logtype_dsi, "dsi_stream_read: len:%d, %s",
293                   dsi->socket, len, (len < 0) ? strerror(errno) : "unexpected EOF");
294       }
295       break;
296     }
297   }
298
299   dsi->read_count += stored;
300   return stored;
301 }
302
303 /*
304  * Get "length" bytes from buffer and/or socket. In order to avoid frequent small reads
305  * this tries to read larger chunks (65536 bytes) into a buffer.
306  */
307 static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
308 {
309   size_t len;
310   size_t buflen;
311
312   LOG(log_maxdebug, logtype_dsi, "dsi_buffered_stream_read: %u bytes", length);
313   
314   len = from_buf(dsi, data, length); /* read from buffer dsi->buffer */
315   dsi->read_count += len;
316   if (len == length) {          /* got enough bytes from there ? */
317       return len;               /* yes */
318   }
319
320   /* fill the buffer with 65536 bytes or until buffer is full */
321   buflen = min(65536, dsi->end - dsi->eof);
322   if (buflen > 0) {
323       ssize_t ret;
324       ret = read(dsi->socket, dsi->eof, buflen);
325       if (ret > 0)
326           dsi->eof += ret;
327   }
328
329   /* now get the remaining data */
330   len += dsi_stream_read(dsi, data + len, length - len);
331   return len;
332 }
333
334 /* ---------------------------------------
335 */
336 static void block_sig(DSI *dsi)
337 {
338   dsi->in_write++;
339 }
340
341 /* ---------------------------------------
342 */
343 static void unblock_sig(DSI *dsi)
344 {
345   dsi->in_write--;
346 }
347
348 /* ---------------------------------------
349  * write data. 0 on failure. this assumes that dsi_len will never
350  * cause an overflow in the data buffer. 
351  */
352 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
353 {
354   char block[DSI_BLOCKSIZ];
355   struct iovec iov[2];
356   size_t towrite;
357   ssize_t len;
358
359   LOG(log_maxdebug, logtype_dsi, "dsi_stream_send: %u bytes",
360       length ? length : sizeof(block));
361
362   block[0] = dsi->header.dsi_flags;
363   block[1] = dsi->header.dsi_command;
364   memcpy(block + 2, &dsi->header.dsi_requestID, 
365          sizeof(dsi->header.dsi_requestID));
366   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
367   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
368   memcpy(block + 12, &dsi->header.dsi_reserved,
369          sizeof(dsi->header.dsi_reserved));
370
371   if (!length) { /* just write the header */
372     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
373     return length; /* really 0 on failure, 1 on success */
374   }
375   
376   /* block signals */
377   block_sig(dsi);
378   iov[0].iov_base = block;
379   iov[0].iov_len = sizeof(block);
380   iov[1].iov_base = buf;
381   iov[1].iov_len = length;
382   
383   towrite = sizeof(block) + length;
384   dsi->write_count += towrite;
385   while (towrite > 0) {
386     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
387         !len)
388       continue;
389     
390     if ((size_t)len == towrite) /* wrote everything out */
391       break;
392     else if (len < 0) { /* error */
393       if (errno == EAGAIN || errno == EWOULDBLOCK) {
394           if (!dsi_peek(dsi)) {
395               continue;
396           }
397       }
398       LOG(log_error, logtype_dsi, "dsi_stream_send: %s", strerror(errno));
399       unblock_sig(dsi);
400       return 0;
401     }
402     
403     towrite -= len;
404     if (towrite > length) { /* skip part of header */
405       iov[0].iov_base = (char *) iov[0].iov_base + len;
406       iov[0].iov_len -= len;
407     } else { /* skip to data */
408       if (iov[0].iov_len) {
409         len -= iov[0].iov_len;
410         iov[0].iov_len = 0;
411       }
412       iov[1].iov_base = (char *) iov[1].iov_base + len;
413       iov[1].iov_len -= len;
414     }
415   }
416   
417   unblock_sig(dsi);
418   return 1;
419 }
420
421
422 /* ---------------------------------------
423  * read data. function on success. 0 on failure. data length gets
424  * stored in length variable. this should really use size_t's, but
425  * that would require changes elsewhere. */
426 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
427                        size_t *rlength)
428 {
429   char block[DSI_BLOCKSIZ];
430
431   LOG(log_maxdebug, logtype_dsi, "dsi_stream_receive: %u bytes", ilength);
432
433   /* read in the header */
434   if (dsi_buffered_stream_read(dsi, (u_int8_t *)block, sizeof(block)) != sizeof(block)) 
435     return 0;
436
437   dsi->header.dsi_flags = block[0];
438   dsi->header.dsi_command = block[1];
439   /* FIXME, not the right place, 
440      but we get a server disconnect without reason in the log
441   */
442   if (!block[1]) {
443       LOG(log_error, logtype_dsi, "dsi_stream_receive: invalid packet, fatal");
444       return 0;
445   }
446
447   memcpy(&dsi->header.dsi_requestID, block + 2, 
448          sizeof(dsi->header.dsi_requestID));
449   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
450   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
451   memcpy(&dsi->header.dsi_reserved, block + 12,
452          sizeof(dsi->header.dsi_reserved));
453   dsi->clientID = ntohs(dsi->header.dsi_requestID);
454   
455   /* make sure we don't over-write our buffers. */
456   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
457   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
458     return 0;
459
460   return block[1];
461 }