]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
73e8ba41949aabf411dfcb69f7af4f4368a58f3d
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * $Id: dsi_stream.c,v 1.20 2009-10-26 12:35:56 franklahm Exp $
3  *
4  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
5  * All rights reserved. See COPYRIGHT.
6  *
7  * this file provides the following functions:
8  * dsi_stream_write:    just write a bunch of bytes.
9  * dsi_stream_read:     just read a bunch of bytes.
10  * dsi_stream_send:     send a DSI header + data.
11  * dsi_stream_receive:  read a DSI header + data.
12  */
13
14 #ifdef HAVE_CONFIG_H
15 #include "config.h"
16 #endif /* HAVE_CONFIG_H */
17
18 #define USE_WRITEV
19
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #ifdef HAVE_UNISTD_H
24 #include <unistd.h>
25 #endif
26
27 #include <string.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31
32 #ifdef USE_WRITEV
33 #include <sys/uio.h>
34 #endif
35
36 #include <atalk/logger.h>
37 #include <atalk/dsi.h>
38 #include <netatalk/endian.h>
39 #include <atalk/util.h>
40
41 #define min(a,b)  ((a) < (b) ? (a) : (b))
42
43 #ifndef MSG_MORE
44 #define MSG_MORE 0x8000
45 #endif
46
47 #ifndef MSG_DONTWAIT
48 #define MSG_DONTWAIT 0x40
49 #endif
50
51 /* ------------------------- 
52  * we don't use a circular buffer.
53 */
54 static void dsi_init_buffer(DSI *dsi)
55 {
56     if (!dsi->buffer) {
57         /* XXX config options */
58         dsi->maxsize = 6 * dsi->server_quantum;
59         if (!dsi->maxsize)
60             dsi->maxsize = 6 * DSI_SERVQUANT_DEF;
61         dsi->buffer = malloc(dsi->maxsize);
62         if (!dsi->buffer) {
63             return;
64         }
65         dsi->start = dsi->buffer;
66         dsi->eof = dsi->buffer;
67         dsi->end = dsi->buffer + dsi->maxsize;
68     }
69 }
70
71 /* ---------------------- 
72    afpd is sleeping too much while trying to send something.
73    May be there's no reader or the reader is also sleeping in write,
74    look if there's some data for us to read, hopefully it will wake up
75    the reader
76 */
77 static int dsi_buffer(DSI *dsi)
78 {
79     fd_set readfds, writefds;
80     int    len;
81     int    maxfd;
82
83     /* non blocking mode */
84     if (setnonblock(dsi->socket, 1) < 0) {
85         /* can't do it! exit without error it will sleep to death below */
86         LOG(log_error, logtype_default, "dsi_buffer: ioctl non blocking mode %s", strerror(errno));
87         return 0;
88     }
89     
90     FD_ZERO(&readfds);
91     FD_ZERO(&writefds);
92     FD_SET( dsi->socket, &readfds);
93     FD_SET( dsi->socket, &writefds);
94     maxfd = dsi->socket +1;
95     while (1) {
96         FD_SET( dsi->socket, &readfds);
97         FD_SET( dsi->socket, &writefds);
98         if (select( maxfd, &readfds, &writefds, NULL, NULL) <= 0)
99             break;
100
101         if ( !FD_ISSET(dsi->socket, &readfds)) {
102             /* nothing waiting in the read queue */
103             break;
104         }
105         dsi_init_buffer(dsi);
106         len = dsi->end - dsi->eof;
107
108         if (len <= 0) {
109             /* ouch, our buffer is full ! 
110              * fall back to blocking IO 
111              * could block and disconnect but it's better than a cpu hog
112              */
113             break;
114         }
115
116         len = read(dsi->socket, dsi->eof, len);
117         if (len <= 0)
118             break;
119         dsi->eof += len;
120         if ( FD_ISSET(dsi->socket, &writefds)) {
121             /* we can write again at last */
122             break;
123         }
124     }
125     if (setnonblock(dsi->socket, 0) < 0) {
126         /* can't do it! afpd will fail very quickly */
127         LOG(log_error, logtype_default, "dsi_buffer: ioctl blocking mode %s", strerror(errno));
128         return -1;
129     }
130     return 0;
131 }
132
133 /* ------------------------------
134  * write raw data. return actual bytes read. checks against EINTR
135  * aren't necessary if all of the signals have SA_RESTART
136  * specified. */
137 ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
138 {
139   size_t written;
140   ssize_t len;
141 #if 0
142   /* FIXME sometime it's slower */
143   unsigned int flags = (mode)?MSG_MORE:0;
144 #endif
145   unsigned int flags = 0;
146
147 #if 0
148   /* XXX there's no MSG_DONTWAIT in recv ?? so we have to play with ioctl
149   */ 
150   flags |= MSG_DONTWAIT;
151 #endif
152   
153   dsi->in_write++;
154   written = 0;
155   while (written < length) {
156     if ((-1 == (len = send(dsi->socket, (u_int8_t *) data + written,
157                       length - written, flags)) && errno == EINTR) ||
158         !len)
159       continue;
160
161     if (len < 0) {
162       if (errno == EAGAIN || errno == EWOULDBLOCK) {
163           if (mode == DSI_NOWAIT && written == 0) {
164               /* DSI_NOWAIT is used by attention
165                  give up in this case.
166               */
167               return -1;
168           }
169           if (dsi_buffer(dsi)) {
170               /* can't go back to blocking mode, exit, the next read
171                  will return with an error and afpd will die.
172               */
173               break;
174           }
175           continue;
176       }
177       LOG(log_error, logtype_default, "dsi_stream_write: %s", strerror(errno));
178       break;
179     }
180     else {
181         written += len;
182     }
183   }
184
185   dsi->write_count += written;
186   dsi->in_write--;
187   return written;
188 }
189
190
191 /* ---------------------------------
192 */
193 #ifdef WITH_SENDFILE
194 ssize_t dsi_stream_read_file(DSI *dsi, int fromfd, off_t offset, const size_t length)
195 {
196   size_t written;
197   ssize_t len;
198
199   dsi->in_write++;
200   written = 0;
201
202   while (written < length) {
203     len = sys_sendfile(dsi->socket, fromfd, &offset, length - written);
204         
205     if (len < 0) {
206       if (errno == EINTR)
207           continue;
208       if (errno == EINVAL || errno == ENOSYS)
209           return -1;
210           
211       if (errno == EAGAIN || errno == EWOULDBLOCK) {
212           if (dsi_buffer(dsi)) {
213               /* can't go back to blocking mode, exit, the next read
214                  will return with an error and afpd will die.
215               */
216               break;
217           }
218           continue;
219       }
220       LOG(log_error, logtype_default, "dsi_stream_write: %s", strerror(errno));
221       break;
222     }
223     else if (!len) {
224         /* afpd is going to exit */
225         errno = EIO;
226         return -1; /* I think we're at EOF here... */
227     }
228     else 
229         written += len;
230   }
231
232   dsi->write_count += written;
233   dsi->in_write--;
234   return written;
235 }
236 #endif
237
238 /* 
239  * Return all bytes up to count from dsi->buffer if there are any buffered there
240  */
241 static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
242 {
243     size_t nbe = 0;
244     
245     if (dsi->start) {        
246         nbe = dsi->eof - dsi->start;
247
248         if (nbe > 0) {
249            nbe = min((size_t)nbe, count);
250            memcpy(buf, dsi->start, nbe);
251            dsi->start += nbe;
252
253            if (dsi->eof == dsi->start) 
254                dsi->start = dsi->eof = dsi->buffer;
255
256         }
257     }
258     return nbe;
259 }
260
261 /*
262  * Get bytes from buffer dsi->buffer or read from socket
263  *
264  * 1. Check if there are bytes in the the dsi->buffer buffer.
265  * 2. Return bytes from (1) if yes.
266  *    Note: this may return fewer bytes then requested in count !!
267  * 3. If the buffer was empty, read from the socket.
268  */
269 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
270 {
271     ssize_t nbe;
272     
273     if (!count)
274         return 0;
275
276     nbe = from_buf(dsi, buf, count); /* 1. */
277     if (nbe)
278         return nbe;             /* 2. */
279   
280     return read(dsi->socket, buf, count); /* 3. */
281 }
282
283 /*
284  * Essentially a loop around buf_read() to ensure "length" bytes are read
285  * from dsi->buffer and/or the socket.
286  */
287 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
288 {
289   size_t stored;
290   ssize_t len;
291   
292   stored = 0;
293   while (stored < length) {
294     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
295     if (len == -1 && errno == EINTR)
296       continue;
297     else if (len > 0)
298       stored += len;
299     else { /* eof or error */
300       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
301       if (len || stored || dsi->read_count) {
302           LOG(log_error, logtype_default, "dsi_stream_read(%d): %s", len, (len < 0)?strerror(errno):"unexpected EOF");
303       }
304       break;
305     }
306   }
307
308   dsi->read_count += stored;
309   return stored;
310 }
311
312 /*
313  * Get "length" bytes from buffer and/or socket. In order to avoid frequent small reads
314  * this tries to read larger chunks (8192 bytes) into a buffer.
315  */
316 static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
317 {
318   size_t len;
319   size_t buflen;
320   
321   dsi_init_buffer(dsi);
322   len = from_buf(dsi, data, length); /* read from buffer dsi->buffer */
323   dsi->read_count += len;
324   if (len == length) {          /* got enough bytes from there ? */
325       return len;               /* yes */
326   }
327
328   /* fill the buffer with 8192 bytes or until buffer is full */
329   buflen = min(8192, dsi->end - dsi->eof);
330   if (buflen > 0) {
331       ssize_t ret;
332       ret = read(dsi->socket, dsi->eof, buflen);
333       if (ret > 0)
334           dsi->eof += ret;
335   }
336
337   /* now get the remaining data */
338   len += dsi_stream_read(dsi, data + len, length - len);
339   return len;
340 }
341
342 /* ---------------------------------------
343 */
344 void dsi_sleep(DSI *dsi, const int state)
345 {
346     dsi->asleep = state;
347 }
348
349 /* ---------------------------------------
350 */
351 static void block_sig(DSI *dsi)
352 {
353   dsi->in_write++;
354 }
355
356 /* ---------------------------------------
357 */
358 static void unblock_sig(DSI *dsi)
359 {
360   dsi->in_write--;
361 }
362
363 /* ---------------------------------------
364  * write data. 0 on failure. this assumes that dsi_len will never
365  * cause an overflow in the data buffer. 
366  */
367 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
368 {
369   char block[DSI_BLOCKSIZ];
370 #ifdef USE_WRITEV
371   struct iovec iov[2];
372   size_t towrite;
373   ssize_t len;
374 #endif /* USE_WRITEV */
375
376   block[0] = dsi->header.dsi_flags;
377   block[1] = dsi->header.dsi_command;
378   memcpy(block + 2, &dsi->header.dsi_requestID, 
379          sizeof(dsi->header.dsi_requestID));
380   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
381   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
382   memcpy(block + 12, &dsi->header.dsi_reserved,
383          sizeof(dsi->header.dsi_reserved));
384
385   if (!length) { /* just write the header */
386     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
387     return length; /* really 0 on failure, 1 on success */
388   }
389   
390   /* block signals */
391   block_sig(dsi);
392 #ifdef USE_WRITEV
393   iov[0].iov_base = block;
394   iov[0].iov_len = sizeof(block);
395   iov[1].iov_base = buf;
396   iov[1].iov_len = length;
397   
398   towrite = sizeof(block) + length;
399   dsi->write_count += towrite;
400   while (towrite > 0) {
401     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
402         !len)
403       continue;
404     
405     if ((size_t)len == towrite) /* wrote everything out */
406       break;
407     else if (len < 0) { /* error */
408       if (errno == EAGAIN || errno == EWOULDBLOCK) {
409           if (!dsi_buffer(dsi)) {
410               continue;
411           }
412       }
413       LOG(log_error, logtype_default, "dsi_stream_send: %s", strerror(errno));
414       unblock_sig(dsi);
415       return 0;
416     }
417     
418     towrite -= len;
419     if (towrite > length) { /* skip part of header */
420       iov[0].iov_base = (char *) iov[0].iov_base + len;
421       iov[0].iov_len -= len;
422     } else { /* skip to data */
423       if (iov[0].iov_len) {
424         len -= iov[0].iov_len;
425         iov[0].iov_len = 0;
426       }
427       iov[1].iov_base = (char *) iov[1].iov_base + len;
428       iov[1].iov_len -= len;
429     }
430   }
431   
432 #else /* USE_WRITEV */
433   /* write the header then data */
434   if ((dsi_stream_write(dsi, block, sizeof(block), 1) != sizeof(block)) ||
435             (dsi_stream_write(dsi, buf, length, 0) != length)) {
436       unblock_sig(dsi);
437       return 0;
438   }
439 #endif /* USE_WRITEV */
440
441   unblock_sig(dsi);
442   return 1;
443 }
444
445
446 /* ---------------------------------------
447  * read data. function on success. 0 on failure. data length gets
448  * stored in length variable. this should really use size_t's, but
449  * that would require changes elsewhere. */
450 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
451                        size_t *rlength)
452 {
453   char block[DSI_BLOCKSIZ];
454
455   /* read in the header */
456   if (dsi_buffered_stream_read(dsi, (u_int8_t *)block, sizeof(block)) != sizeof(block)) 
457     return 0;
458
459   dsi->header.dsi_flags = block[0];
460   dsi->header.dsi_command = block[1];
461   /* FIXME, not the right place, 
462      but we get a server disconnect without reason in the log
463   */
464   if (!block[1]) {
465       LOG(log_error, logtype_default, "dsi_stream_receive: invalid packet, fatal");
466       return 0;
467   }
468
469   memcpy(&dsi->header.dsi_requestID, block + 2, 
470          sizeof(dsi->header.dsi_requestID));
471   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
472   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
473   memcpy(&dsi->header.dsi_reserved, block + 12,
474          sizeof(dsi->header.dsi_reserved));
475   dsi->clientID = ntohs(dsi->header.dsi_requestID);
476   
477   /* make sure we don't over-write our buffers. */
478   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
479   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
480     return 0;
481
482   return block[1];
483 }