]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
logout is not disconnect
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
3  * All rights reserved. See COPYRIGHT.
4  *
5  * this file provides the following functions:
6  * dsi_stream_write:    just write a bunch of bytes.
7  * dsi_stream_read:     just read a bunch of bytes.
8  * dsi_stream_send:     send a DSI header + data.
9  * dsi_stream_receive:  read a DSI header + data.
10  */
11
12 #ifdef HAVE_CONFIG_H
13 #include "config.h"
14 #endif /* HAVE_CONFIG_H */
15
16 #include <stdio.h>
17 #include <stdlib.h>
18
19 #ifdef HAVE_UNISTD_H
20 #include <unistd.h>
21 #endif
22
23 #include <string.h>
24 #include <errno.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/uio.h>
28
29 #include <atalk/logger.h>
30 #include <atalk/dsi.h>
31 #include <netatalk/endian.h>
32 #include <atalk/util.h>
33
34 #define min(a,b)  ((a) < (b) ? (a) : (b))
35
36 #ifndef MSG_MORE
37 #define MSG_MORE 0x8000
38 #endif
39
40 #ifndef MSG_DONTWAIT
41 #define MSG_DONTWAIT 0x40
42 #endif
43
44 /* ---------------------- 
45    afpd is sleeping too much while trying to send something.
46    May be there's no reader or the reader is also sleeping in write,
47    look if there's some data for us to read, hopefully it will wake up
48    the reader so we can write again.
49 */
50 static int dsi_peek(DSI *dsi)
51 {
52     fd_set readfds, writefds;
53     int    len;
54     int    maxfd;
55     int    ret;
56
57     LOG(log_debug, logtype_dsi, "dsi_peek");
58
59     FD_ZERO(&readfds);
60     FD_ZERO(&writefds);
61     FD_SET( dsi->socket, &readfds);
62     FD_SET( dsi->socket, &writefds);
63     maxfd = dsi->socket +1;
64
65     while (1) {
66         FD_SET( dsi->socket, &readfds);
67         FD_SET( dsi->socket, &writefds);
68
69         /* No timeout: if there's nothing to read nor nothing to write,
70          * we've got nothing to do at all */
71         if ((ret = select( maxfd, &readfds, &writefds, NULL, NULL)) <= 0) {
72             if (ret == -1 && errno == EINTR)
73                 /* we might have been interrupted by out timer, so restart select */
74                 continue;
75             /* give up */
76             LOG(log_error, logtype_dsi, "dsi_peek: unexpected select return: %d %s",
77                 ret, ret < 0 ? strerror(errno) : "");
78             return -1;
79         }
80
81         /* Check if there's sth to read, hopefully reading that will unblock the client */
82         if (FD_ISSET(dsi->socket, &readfds)) {
83             len = dsi->end - dsi->eof;
84
85             if (len <= 0) {
86                 /* ouch, our buffer is full ! fall back to blocking IO 
87                  * could block and disconnect but it's better than a cpu hog */
88                 LOG(log_warning, logtype_dsi, "dsi_peek: read buffer is full");
89                 break;
90             }
91
92             if ((len = read(dsi->socket, dsi->eof, len)) <= 0) {
93                 if (len == 0) {
94                     LOG(log_error, logtype_dsi, "dsi_peek: EOF");
95                     return -1;
96                 }
97                 LOG(log_error, logtype_dsi, "dsi_peek: read: %s", strerror(errno));
98                 if (errno == EAGAIN)
99                     continue;
100                 return -1;
101             }
102             LOG(log_debug, logtype_dsi, "dsi_peek: read %d bytes", len);
103
104             dsi->eof += len;
105         }
106
107         if (FD_ISSET(dsi->socket, &writefds)) {
108             /* we can write again */
109             LOG(log_debug, logtype_dsi, "dsi_peek: can write again");
110             break;
111         }
112     }
113
114     return 0;
115 }
116
117 /* ------------------------------
118  * write raw data. return actual bytes read. checks against EINTR
119  * aren't necessary if all of the signals have SA_RESTART
120  * specified. */
121 ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
122 {
123   size_t written;
124   ssize_t len;
125   unsigned int flags = 0;
126
127   dsi->in_write++;
128   written = 0;
129
130   LOG(log_maxdebug, logtype_dsi, "dsi_stream_write: sending %u bytes", length);
131
132   while (written < length) {
133       len = send(dsi->socket, (u_int8_t *) data + written, length - written, flags);
134       if (len >= 0) {
135           written += len;
136           continue;
137       }
138
139       if (errno == EINTR)
140           continue;
141
142       if (errno == EAGAIN || errno == EWOULDBLOCK) {
143           if (mode == DSI_NOWAIT && written == 0) {
144               /* DSI_NOWAIT is used by attention give up in this case. */
145               written = -1;
146               goto exit;
147           }
148
149           /* Try to read sth. in order to break up possible deadlock */
150           if (dsi_peek(dsi) != 0) {
151               written = -1;
152               goto exit;
153           }
154           /* Now try writing again */
155           continue;
156       }
157
158       LOG(log_error, logtype_dsi, "dsi_stream_write: %s", strerror(errno));
159       written = -1;
160       goto exit;
161   }
162
163   dsi->write_count += written;
164
165 exit:
166   dsi->in_write--;
167   return written;
168 }
169
170
171 /* ---------------------------------
172 */
173 #ifdef WITH_SENDFILE
174 ssize_t dsi_stream_read_file(DSI *dsi, int fromfd, off_t offset, const size_t length)
175 {
176   size_t written;
177   ssize_t len;
178
179   LOG(log_maxdebug, logtype_dsi, "dsi_stream_read_file: sending %u bytes", length);
180
181   dsi->in_write++;
182   written = 0;
183
184   while (written < length) {
185     len = sys_sendfile(dsi->socket, fromfd, &offset, length - written);
186         
187     if (len < 0) {
188       if (errno == EINTR)
189           continue;
190       if (errno == EINVAL || errno == ENOSYS)
191           return -1;
192           
193       if (errno == EAGAIN || errno == EWOULDBLOCK) {
194           if (dsi_peek(dsi)) {
195               /* can't go back to blocking mode, exit, the next read
196                  will return with an error and afpd will die.
197               */
198               break;
199           }
200           continue;
201       }
202       LOG(log_error, logtype_dsi, "dsi_stream_read_file: %s", strerror(errno));
203       break;
204     }
205     else if (!len) {
206         /* afpd is going to exit */
207         errno = EIO;
208         return -1; /* I think we're at EOF here... */
209     }
210     else 
211         written += len;
212   }
213
214   dsi->write_count += written;
215   dsi->in_write--;
216   return written;
217 }
218 #endif
219
220 /* 
221  * Return all bytes up to count from dsi->buffer if there are any buffered there
222  */
223 static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
224 {
225     size_t nbe = 0;
226
227     LOG(log_maxdebug, logtype_dsi, "from_buf: %u bytes", count);
228     
229     if (dsi->start) {        
230         nbe = dsi->eof - dsi->start;
231
232         if (nbe > 0) {
233            nbe = min((size_t)nbe, count);
234            memcpy(buf, dsi->start, nbe);
235            dsi->start += nbe;
236
237            if (dsi->eof == dsi->start) 
238                dsi->start = dsi->eof = dsi->buffer;
239
240         }
241     }
242     return nbe;
243 }
244
245 /*
246  * Get bytes from buffer dsi->buffer or read from socket
247  *
248  * 1. Check if there are bytes in the the dsi->buffer buffer.
249  * 2. Return bytes from (1) if yes.
250  *    Note: this may return fewer bytes then requested in count !!
251  * 3. If the buffer was empty, read from the socket.
252  */
253 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
254 {
255     ssize_t nbe;
256
257     LOG(log_maxdebug, logtype_dsi, "buf_read: %u bytes", count);
258
259     if (!count)
260         return 0;
261
262     nbe = from_buf(dsi, buf, count); /* 1. */
263     if (nbe)
264         return nbe;             /* 2. */
265   
266     return readt(dsi->socket, buf, count, 0, 1); /* 3. */
267 }
268
269 /*
270  * Essentially a loop around buf_read() to ensure "length" bytes are read
271  * from dsi->buffer and/or the socket.
272  */
273 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
274 {
275   size_t stored;
276   ssize_t len;
277
278   LOG(log_maxdebug, logtype_dsi, "dsi_stream_read: %u bytes", length);
279
280   stored = 0;
281   while (stored < length) {
282     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
283     if (len == -1 && (errno == EINTR || errno == EAGAIN)) {
284       LOG(log_debug, logtype_dsi, "dsi_stream_read: select read loop");
285       continue;
286     } else if (len > 0) {
287       stored += len;
288     } else { /* eof or error */
289       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
290       if (len || stored || dsi->read_count) {
291           if (! (dsi->flags & DSI_DISCONNECTED)) {
292               LOG(log_error, logtype_dsi, "dsi_stream_read: len:%d, %s",
293                   len, (len < 0) ? strerror(errno) : "unexpected EOF");
294               AFP_PANIC("FIXME");
295           }
296       }
297       break;
298     }
299   }
300
301   dsi->read_count += stored;
302   return stored;
303 }
304
305 /*
306  * Get "length" bytes from buffer and/or socket. In order to avoid frequent small reads
307  * this tries to read larger chunks (65536 bytes) into a buffer.
308  */
309 static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
310 {
311   size_t len;
312   size_t buflen;
313
314   LOG(log_maxdebug, logtype_dsi, "dsi_buffered_stream_read: %u bytes", length);
315   
316   len = from_buf(dsi, data, length); /* read from buffer dsi->buffer */
317   dsi->read_count += len;
318   if (len == length) {          /* got enough bytes from there ? */
319       return len;               /* yes */
320   }
321
322   /* fill the buffer with 65536 bytes or until buffer is full */
323   buflen = min(65536, dsi->end - dsi->eof);
324   if (buflen > 0) {
325       ssize_t ret;
326       ret = read(dsi->socket, dsi->eof, buflen);
327       if (ret > 0)
328           dsi->eof += ret;
329   }
330
331   /* now get the remaining data */
332   len += dsi_stream_read(dsi, data + len, length - len);
333   return len;
334 }
335
336 /* ---------------------------------------
337 */
338 static void block_sig(DSI *dsi)
339 {
340   dsi->in_write++;
341 }
342
343 /* ---------------------------------------
344 */
345 static void unblock_sig(DSI *dsi)
346 {
347   dsi->in_write--;
348 }
349
350 /* ---------------------------------------
351  * write data. 0 on failure. this assumes that dsi_len will never
352  * cause an overflow in the data buffer. 
353  */
354 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
355 {
356   char block[DSI_BLOCKSIZ];
357   struct iovec iov[2];
358   size_t towrite;
359   ssize_t len;
360
361   LOG(log_maxdebug, logtype_dsi, "dsi_stream_send: %u bytes",
362       length ? length : sizeof(block));
363
364   block[0] = dsi->header.dsi_flags;
365   block[1] = dsi->header.dsi_command;
366   memcpy(block + 2, &dsi->header.dsi_requestID, 
367          sizeof(dsi->header.dsi_requestID));
368   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
369   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
370   memcpy(block + 12, &dsi->header.dsi_reserved,
371          sizeof(dsi->header.dsi_reserved));
372
373   if (!length) { /* just write the header */
374     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
375     return length; /* really 0 on failure, 1 on success */
376   }
377   
378   /* block signals */
379   block_sig(dsi);
380   iov[0].iov_base = block;
381   iov[0].iov_len = sizeof(block);
382   iov[1].iov_base = buf;
383   iov[1].iov_len = length;
384   
385   towrite = sizeof(block) + length;
386   dsi->write_count += towrite;
387   while (towrite > 0) {
388     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
389         !len)
390       continue;
391     
392     if ((size_t)len == towrite) /* wrote everything out */
393       break;
394     else if (len < 0) { /* error */
395       if (errno == EAGAIN || errno == EWOULDBLOCK) {
396           if (!dsi_peek(dsi)) {
397               continue;
398           }
399       }
400       LOG(log_error, logtype_dsi, "dsi_stream_send: %s", strerror(errno));
401       unblock_sig(dsi);
402       return 0;
403     }
404     
405     towrite -= len;
406     if (towrite > length) { /* skip part of header */
407       iov[0].iov_base = (char *) iov[0].iov_base + len;
408       iov[0].iov_len -= len;
409     } else { /* skip to data */
410       if (iov[0].iov_len) {
411         len -= iov[0].iov_len;
412         iov[0].iov_len = 0;
413       }
414       iov[1].iov_base = (char *) iov[1].iov_base + len;
415       iov[1].iov_len -= len;
416     }
417   }
418   
419   unblock_sig(dsi);
420   return 1;
421 }
422
423
424 /* ---------------------------------------
425  * read data. function on success. 0 on failure. data length gets
426  * stored in length variable. this should really use size_t's, but
427  * that would require changes elsewhere. */
428 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
429                        size_t *rlength)
430 {
431   char block[DSI_BLOCKSIZ];
432
433   LOG(log_maxdebug, logtype_dsi, "dsi_stream_receive: %u bytes", ilength);
434
435   /* read in the header */
436   if (dsi_buffered_stream_read(dsi, (u_int8_t *)block, sizeof(block)) != sizeof(block)) 
437     return 0;
438
439   dsi->header.dsi_flags = block[0];
440   dsi->header.dsi_command = block[1];
441   /* FIXME, not the right place, 
442      but we get a server disconnect without reason in the log
443   */
444   if (!block[1]) {
445       LOG(log_error, logtype_dsi, "dsi_stream_receive: invalid packet, fatal");
446       return 0;
447   }
448
449   memcpy(&dsi->header.dsi_requestID, block + 2, 
450          sizeof(dsi->header.dsi_requestID));
451   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
452   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
453   memcpy(&dsi->header.dsi_reserved, block + 12,
454          sizeof(dsi->header.dsi_reserved));
455   dsi->clientID = ntohs(dsi->header.dsi_requestID);
456   
457   /* make sure we don't over-write our buffers. */
458   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
459   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
460     return 0;
461
462   return block[1];
463 }