]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
2f85ac318d225bd043504b0701e0c77061ae6d1d
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * $Id: dsi_stream.c,v 1.20 2009-10-26 12:35:56 franklahm Exp $
3  *
4  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
5  * All rights reserved. See COPYRIGHT.
6  *
7  * this file provides the following functions:
8  * dsi_stream_write:    just write a bunch of bytes.
9  * dsi_stream_read:     just read a bunch of bytes.
10  * dsi_stream_send:     send a DSI header + data.
11  * dsi_stream_receive:  read a DSI header + data.
12  */
13
14 #ifdef HAVE_CONFIG_H
15 #include "config.h"
16 #endif /* HAVE_CONFIG_H */
17
18 #define USE_WRITEV
19
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #ifdef HAVE_UNISTD_H
24 #include <unistd.h>
25 #endif
26
27 #include <string.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31
32 #ifdef USE_WRITEV
33 #include <sys/uio.h>
34 #endif
35
36 #include <atalk/logger.h>
37 #include <atalk/dsi.h>
38 #include <netatalk/endian.h>
39 #include <atalk/util.h>
40
41 #define min(a,b)  ((a) < (b) ? (a) : (b))
42
43 #ifndef MSG_MORE
44 #define MSG_MORE 0x8000
45 #endif
46
47 #ifndef MSG_DONTWAIT
48 #define MSG_DONTWAIT 0x40
49 #endif
50
51 /* ------------------------- 
52  * we don't use a circular buffer.
53 */
54 static void dsi_init_buffer(DSI *dsi)
55 {
56     if (!dsi->buffer) {
57         /* XXX config options */
58         dsi->maxsize = 6 * dsi->server_quantum;
59         if (!dsi->maxsize)
60             dsi->maxsize = 6 * DSI_SERVQUANT_DEF;
61         dsi->buffer = malloc(dsi->maxsize);
62         if (!dsi->buffer) {
63             return;
64         }
65         dsi->start = dsi->buffer;
66         dsi->eof = dsi->buffer;
67         dsi->end = dsi->buffer + dsi->maxsize;
68     }
69 }
70
71 /* ---------------------- 
72    afpd is sleeping too much while trying to send something.
73    May be there's no reader or the reader is also sleeping in write,
74    look if there's some data for us to read, hopefully it will wake up
75    the reader
76 */
77 static int dsi_buffer(DSI *dsi)
78 {
79     fd_set readfds, writefds;
80     int    len;
81     int    maxfd;
82
83     /* non blocking mode */
84     if (setnonblock(dsi->socket, 1) < 0) {
85         /* can't do it! exit without error it will sleep to death below */
86         LOG(log_error, logtype_dsi, "dsi_buffer: ioctl non blocking mode %s", strerror(errno));
87         return 0;
88     }
89     
90     FD_ZERO(&readfds);
91     FD_ZERO(&writefds);
92     FD_SET( dsi->socket, &readfds);
93     FD_SET( dsi->socket, &writefds);
94     maxfd = dsi->socket +1;
95     while (1) {
96         FD_SET( dsi->socket, &readfds);
97         FD_SET( dsi->socket, &writefds);
98         if (select( maxfd, &readfds, &writefds, NULL, NULL) <= 0)
99             break;
100
101         if ( !FD_ISSET(dsi->socket, &readfds)) {
102             /* nothing waiting in the read queue */
103             break;
104         }
105         dsi_init_buffer(dsi);
106         len = dsi->end - dsi->eof;
107
108         if (len <= 0) {
109             /* ouch, our buffer is full ! 
110              * fall back to blocking IO 
111              * could block and disconnect but it's better than a cpu hog
112              */
113             break;
114         }
115
116         len = read(dsi->socket, dsi->eof, len);
117         if (len <= 0)
118             break;
119         dsi->eof += len;
120         if ( FD_ISSET(dsi->socket, &writefds)) {
121             /* we can write again at last */
122             break;
123         }
124     }
125     if (setnonblock(dsi->socket, 0) < 0) {
126         /* can't do it! afpd will fail very quickly */
127         LOG(log_error, logtype_dsi, "dsi_buffer: ioctl blocking mode %s", strerror(errno));
128         return -1;
129     }
130     return 0;
131 }
132
133 /* ------------------------------
134  * write raw data. return actual bytes read. checks against EINTR
135  * aren't necessary if all of the signals have SA_RESTART
136  * specified. */
137 ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
138 {
139   size_t written;
140   ssize_t len;
141 #if 0
142   /* FIXME sometime it's slower */
143   unsigned int flags = (mode)?MSG_MORE:0;
144 #endif
145   unsigned int flags = 0;
146
147 #if 0
148   /* XXX there's no MSG_DONTWAIT in recv ?? so we have to play with ioctl
149   */ 
150   flags |= MSG_DONTWAIT;
151 #endif
152   
153   dsi->in_write++;
154   written = 0;
155   while (written < length) {
156       len = send(dsi->socket, (u_int8_t *) data + written, length - written, flags);
157       if ((len == 0) || (len == -1 && errno == EINTR))
158           continue;
159
160     if (len < 0) {
161       if (errno == EAGAIN || errno == EWOULDBLOCK) {
162           if (mode == DSI_NOWAIT && written == 0) {
163               /* DSI_NOWAIT is used by attention give up in this case. */
164               return -1;
165           }
166           if (dsi_buffer(dsi)) {
167               /* can't go back to blocking mode, exit, the next read
168                  will return with an error and afpd will die.
169               */
170               break;
171           }
172           continue;
173       }
174       LOG(log_error, logtype_dsi, "dsi_stream_write: %s", strerror(errno));
175       break;
176     }
177     else {
178         written += len;
179     }
180   }
181
182   dsi->write_count += written;
183   dsi->in_write--;
184   return written;
185 }
186
187
188 /* ---------------------------------
189 */
190 #ifdef WITH_SENDFILE
191 ssize_t dsi_stream_read_file(DSI *dsi, int fromfd, off_t offset, const size_t length)
192 {
193   size_t written;
194   ssize_t len;
195
196   dsi->in_write++;
197   written = 0;
198
199   while (written < length) {
200     len = sys_sendfile(dsi->socket, fromfd, &offset, length - written);
201         
202     if (len < 0) {
203       if (errno == EINTR)
204           continue;
205       if (errno == EINVAL || errno == ENOSYS)
206           return -1;
207           
208       if (errno == EAGAIN || errno == EWOULDBLOCK) {
209           if (dsi_buffer(dsi)) {
210               /* can't go back to blocking mode, exit, the next read
211                  will return with an error and afpd will die.
212               */
213               break;
214           }
215           continue;
216       }
217       LOG(log_error, logtype_dsi, "dsi_stream_write: %s", strerror(errno));
218       break;
219     }
220     else if (!len) {
221         /* afpd is going to exit */
222         errno = EIO;
223         return -1; /* I think we're at EOF here... */
224     }
225     else 
226         written += len;
227   }
228
229   dsi->write_count += written;
230   dsi->in_write--;
231   return written;
232 }
233 #endif
234
235 /* 
236  * Return all bytes up to count from dsi->buffer if there are any buffered there
237  */
238 static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
239 {
240     size_t nbe = 0;
241     
242     if (dsi->start) {        
243         nbe = dsi->eof - dsi->start;
244
245         if (nbe > 0) {
246            nbe = min((size_t)nbe, count);
247            memcpy(buf, dsi->start, nbe);
248            dsi->start += nbe;
249
250            if (dsi->eof == dsi->start) 
251                dsi->start = dsi->eof = dsi->buffer;
252
253         }
254     }
255     return nbe;
256 }
257
258 /*
259  * Get bytes from buffer dsi->buffer or read from socket
260  *
261  * 1. Check if there are bytes in the the dsi->buffer buffer.
262  * 2. Return bytes from (1) if yes.
263  *    Note: this may return fewer bytes then requested in count !!
264  * 3. If the buffer was empty, read from the socket.
265  */
266 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
267 {
268     ssize_t nbe;
269     
270     if (!count)
271         return 0;
272
273     nbe = from_buf(dsi, buf, count); /* 1. */
274     if (nbe)
275         return nbe;             /* 2. */
276   
277     return read(dsi->socket, buf, count); /* 3. */
278 }
279
280 /*
281  * Essentially a loop around buf_read() to ensure "length" bytes are read
282  * from dsi->buffer and/or the socket.
283  */
284 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
285 {
286   size_t stored;
287   ssize_t len;
288   
289   stored = 0;
290   while (stored < length) {
291     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
292     if (len == -1 && errno == EINTR)
293       continue;
294     else if (len > 0)
295       stored += len;
296     else { /* eof or error */
297       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
298       if (len || stored || dsi->read_count) {
299           LOG(log_error, logtype_dsi, "dsi_stream_read(%d): %s", len, (len < 0)?strerror(errno):"unexpected EOF");
300       }
301       break;
302     }
303   }
304
305   dsi->read_count += stored;
306   return stored;
307 }
308
309 /*
310  * Get "length" bytes from buffer and/or socket. In order to avoid frequent small reads
311  * this tries to read larger chunks (8192 bytes) into a buffer.
312  */
313 static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
314 {
315   size_t len;
316   size_t buflen;
317   
318   dsi_init_buffer(dsi);
319   len = from_buf(dsi, data, length); /* read from buffer dsi->buffer */
320   dsi->read_count += len;
321   if (len == length) {          /* got enough bytes from there ? */
322       return len;               /* yes */
323   }
324
325   /* fill the buffer with 8192 bytes or until buffer is full */
326   buflen = min(8192, dsi->end - dsi->eof);
327   if (buflen > 0) {
328       ssize_t ret;
329       ret = read(dsi->socket, dsi->eof, buflen);
330       if (ret > 0)
331           dsi->eof += ret;
332   }
333
334   /* now get the remaining data */
335   len += dsi_stream_read(dsi, data + len, length - len);
336   return len;
337 }
338
339 /* ---------------------------------------
340 */
341 void dsi_sleep(DSI *dsi, const int state)
342 {
343     dsi->asleep = state;
344 }
345
346 /* ---------------------------------------
347 */
348 static void block_sig(DSI *dsi)
349 {
350   dsi->in_write++;
351 }
352
353 /* ---------------------------------------
354 */
355 static void unblock_sig(DSI *dsi)
356 {
357   dsi->in_write--;
358 }
359
360 /* ---------------------------------------
361  * write data. 0 on failure. this assumes that dsi_len will never
362  * cause an overflow in the data buffer. 
363  */
364 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
365 {
366   char block[DSI_BLOCKSIZ];
367 #ifdef USE_WRITEV
368   struct iovec iov[2];
369   size_t towrite;
370   ssize_t len;
371 #endif /* USE_WRITEV */
372
373   block[0] = dsi->header.dsi_flags;
374   block[1] = dsi->header.dsi_command;
375   memcpy(block + 2, &dsi->header.dsi_requestID, 
376          sizeof(dsi->header.dsi_requestID));
377   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
378   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
379   memcpy(block + 12, &dsi->header.dsi_reserved,
380          sizeof(dsi->header.dsi_reserved));
381
382   if (!length) { /* just write the header */
383     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
384     return length; /* really 0 on failure, 1 on success */
385   }
386   
387   /* block signals */
388   block_sig(dsi);
389 #ifdef USE_WRITEV
390   iov[0].iov_base = block;
391   iov[0].iov_len = sizeof(block);
392   iov[1].iov_base = buf;
393   iov[1].iov_len = length;
394   
395   towrite = sizeof(block) + length;
396   dsi->write_count += towrite;
397   while (towrite > 0) {
398     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
399         !len)
400       continue;
401     
402     if ((size_t)len == towrite) /* wrote everything out */
403       break;
404     else if (len < 0) { /* error */
405       if (errno == EAGAIN || errno == EWOULDBLOCK) {
406           if (!dsi_buffer(dsi)) {
407               continue;
408           }
409       }
410       LOG(log_error, logtype_dsi, "dsi_stream_send: %s", strerror(errno));
411       unblock_sig(dsi);
412       return 0;
413     }
414     
415     towrite -= len;
416     if (towrite > length) { /* skip part of header */
417       iov[0].iov_base = (char *) iov[0].iov_base + len;
418       iov[0].iov_len -= len;
419     } else { /* skip to data */
420       if (iov[0].iov_len) {
421         len -= iov[0].iov_len;
422         iov[0].iov_len = 0;
423       }
424       iov[1].iov_base = (char *) iov[1].iov_base + len;
425       iov[1].iov_len -= len;
426     }
427   }
428   
429 #else /* USE_WRITEV */
430   /* write the header then data */
431   if ((dsi_stream_write(dsi, block, sizeof(block), 1) != sizeof(block)) ||
432             (dsi_stream_write(dsi, buf, length, 0) != length)) {
433       unblock_sig(dsi);
434       return 0;
435   }
436 #endif /* USE_WRITEV */
437
438   unblock_sig(dsi);
439   return 1;
440 }
441
442
443 /* ---------------------------------------
444  * read data. function on success. 0 on failure. data length gets
445  * stored in length variable. this should really use size_t's, but
446  * that would require changes elsewhere. */
447 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
448                        size_t *rlength)
449 {
450   char block[DSI_BLOCKSIZ];
451
452   /* read in the header */
453   if (dsi_buffered_stream_read(dsi, (u_int8_t *)block, sizeof(block)) != sizeof(block)) 
454     return 0;
455
456   dsi->header.dsi_flags = block[0];
457   dsi->header.dsi_command = block[1];
458   /* FIXME, not the right place, 
459      but we get a server disconnect without reason in the log
460   */
461   if (!block[1]) {
462       LOG(log_error, logtype_dsi, "dsi_stream_receive: invalid packet, fatal");
463       return 0;
464   }
465
466   memcpy(&dsi->header.dsi_requestID, block + 2, 
467          sizeof(dsi->header.dsi_requestID));
468   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
469   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
470   memcpy(&dsi->header.dsi_reserved, block + 12,
471          sizeof(dsi->header.dsi_reserved));
472   dsi->clientID = ntohs(dsi->header.dsi_requestID);
473   
474   /* make sure we don't over-write our buffers. */
475   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
476   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
477     return 0;
478
479   return block[1];
480 }