]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
1) try a better workaround for deadlocks when both the server and the client are...
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * $Id: dsi_stream.c,v 1.17 2009-10-25 06:13:11 didg Exp $
3  *
4  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
5  * All rights reserved. See COPYRIGHT.
6  *
7  * this file provides the following functions:
8  * dsi_stream_write:    just write a bunch of bytes.
9  * dsi_stream_read:     just read a bunch of bytes.
10  * dsi_stream_send:     send a DSI header + data.
11  * dsi_stream_receive:  read a DSI header + data.
12  */
13
14 #ifdef HAVE_CONFIG_H
15 #include "config.h"
16 #endif /* HAVE_CONFIG_H */
17
18 #define USE_WRITEV
19
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #ifdef HAVE_UNISTD_H
24 #include <unistd.h>
25 #endif
26
27 #include <string.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31
32 #ifdef USE_WRITEV
33 #include <sys/uio.h>
34 #endif
35
36 #include <atalk/logger.h>
37
38 #include <atalk/dsi.h>
39 #include <netatalk/endian.h>
40 #include <sys/ioctl.h> 
41
42 #define min(a,b)  ((a) < (b) ? (a) : (b))
43
44 #ifndef MSG_MORE
45 #define MSG_MORE 0x8000
46 #endif
47
48 #ifndef MSG_DONTWAIT
49 #define MSG_DONTWAIT 0x40
50 #endif
51
52 /* ------------------------- 
53  * we don't use a circular buffer.
54 */
55 static void dsi_init_buffer(DSI *dsi)
56 {
57     if (!dsi->buffer) {
58         /* XXX config options */
59         dsi->maxsize = 6 * dsi->server_quantum;
60         if (!dsi->maxsize)
61             dsi->maxsize = 6 * DSI_SERVQUANT_DEF;
62         dsi->buffer = malloc(dsi->maxsize);
63         if (!dsi->buffer) {
64             return;
65         }
66         dsi->start = dsi->buffer;
67         dsi->eof = dsi->buffer;
68         dsi->end = dsi->buffer + dsi->maxsize;
69     }
70 }
71
72 /* ---------------------- 
73    afpd is sleeping too much while trying to send something.
74    May be there's no reader or the reader is also sleeping in write,
75    look if there's some data for us to read, hopefully it will wake up
76    the reader
77 */
78 static int dsi_buffer(DSI *dsi)
79 {
80     fd_set readfds, writefds;
81     int    len;
82     int    maxfd;
83     int adr;
84
85     /* non blocking mode */
86     adr = 1;
87     if (ioctl(dsi->socket, FIONBIO, &adr) < 0) {
88         /* can't do it! exit without error it will sleep to death below */
89         LOG(log_error, logtype_default, "dsi_buffer: ioctl non blocking mode %s", strerror(errno));
90         return 0;
91     }
92     
93     FD_ZERO(&readfds);
94     FD_ZERO(&writefds);
95     FD_SET( dsi->socket, &readfds);
96     FD_SET( dsi->socket, &writefds);
97     maxfd = dsi->socket +1;
98     while (1) {
99         FD_SET( dsi->socket, &readfds);
100         FD_SET( dsi->socket, &writefds);
101         if (select( maxfd, &readfds, &writefds, NULL, NULL) <= 0)
102             break;
103
104         if ( !FD_ISSET(dsi->socket, &readfds)) {
105             /* nothing waiting in the read queue */
106             break;
107         }
108         dsi_init_buffer(dsi);
109         len = dsi->end - dsi->eof;
110
111         if (len <= 0) {
112             /* ouch, our buffer is full ! 
113              * fall back to blocking IO 
114              * could block and disconnect but it's better than a cpu hog
115              */
116             break;
117         }
118
119         len = read(dsi->socket, dsi->eof, len);
120         if (len <= 0)
121             break;
122         dsi->eof += len;
123         if ( FD_ISSET(dsi->socket, &writefds)) {
124             /* we can write again at last */
125             break;
126         }
127     }
128     adr = 0;
129     if (ioctl(dsi->socket, FIONBIO, &adr) < 0) {
130         /* can't do it! afpd will fail very quickly */
131         LOG(log_error, logtype_default, "dsi_buffer: ioctl blocking mode %s", strerror(errno));
132         return -1;
133     }
134     return 0;
135 }
136
137 /* ------------------------------
138  * write raw data. return actual bytes read. checks against EINTR
139  * aren't necessary if all of the signals have SA_RESTART
140  * specified. */
141 ssize_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
142 {
143   size_t written;
144   ssize_t len;
145 #if 0
146   /* FIXME sometime it's slower */
147   unsigned int flags = (mode)?MSG_MORE:0;
148 #endif
149   unsigned int flags = 0;
150
151 #if 0
152   /* XXX there's no MSG_DONTWAIT in recv ?? so we have to play with ioctl
153   */ 
154   flags |= MSG_DONTWAIT;
155 #endif
156   
157   dsi->in_write++;
158   written = 0;
159   while (written < length) {
160     if ((-1 == (len = send(dsi->socket, (u_int8_t *) data + written,
161                       length - written, flags)) && errno == EINTR) ||
162         !len)
163       continue;
164
165     if (len < 0) {
166       if (errno == EAGAIN || errno == EWOULDBLOCK) {
167           if (mode == DSI_NOWAIT && written == 0) {
168               /* DSI_NOWAIT is used by attention
169                  give up in this case.
170               */
171               return -1;
172           }
173           if (dsi_buffer(dsi)) {
174               /* can't go back to blocking mode, exit, the next read
175                  will return with an error and afpd will die.
176               */
177               break;
178           }
179           continue;
180       }
181       LOG(log_error, logtype_default, "dsi_stream_write: %s", strerror(errno));
182       break;
183     }
184     else {
185         written += len;
186     }
187   }
188
189   dsi->write_count += written;
190   dsi->in_write--;
191   return written;
192 }
193
194 /* ---------------------------------
195 */
196 static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
197 {
198     size_t nbe = 0;
199     
200     if (dsi->start) {        
201         nbe = dsi->eof - dsi->start;
202
203         if (nbe > 0) {
204            nbe = min((size_t)nbe, count);
205            memcpy(buf, dsi->start, nbe);
206            dsi->start += nbe;
207
208            if (dsi->eof == dsi->start) 
209                dsi->start = dsi->eof = dsi->buffer;
210
211         }
212     }
213     return nbe;
214 }
215
216 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
217 {
218     ssize_t nbe;
219     
220     if (!count)
221         return 0;
222
223     nbe = from_buf(dsi, buf, count);
224     if (nbe)
225         return nbe;
226   
227     return read(dsi->socket, buf, count);
228
229 }
230
231 /* ---------------------------------------
232  * read raw data. return actual bytes read. this will wait until 
233  * it gets length bytes 
234  */
235 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
236 {
237   size_t stored;
238   ssize_t len;
239   
240   stored = 0;
241   while (stored < length) {
242     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
243     if (len == -1 && errno == EINTR)
244       continue;
245     else if (len > 0)
246       stored += len;
247     else { /* eof or error */
248       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
249       if (len || stored || dsi->read_count) {
250           LOG(log_error, logtype_default, "dsi_stream_read(%d): %s", len, (len < 0)?strerror(errno):"unexpected EOF");
251       }
252       break;
253     }
254   }
255
256   dsi->read_count += stored;
257   return stored;
258 }
259
260 /* ---------------------------------------
261  * read raw data. return actual bytes read. this will wait until 
262  * it gets length bytes 
263  */
264 static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
265 {
266   size_t len;
267   size_t buflen;
268   
269   dsi_init_buffer(dsi);
270   len = from_buf(dsi, data, length);
271   dsi->read_count += len;
272   if (len == length) {
273       return len;
274   }
275   
276   buflen = min(8192, dsi->end - dsi->eof);
277   if (buflen > 0) {
278       ssize_t ret;
279       ret = read(dsi->socket, dsi->eof, buflen);
280       if (ret > 0)
281           dsi->eof += ret;
282   }
283   return dsi_stream_read(dsi, data, length -len);
284 }
285
286 /* ---------------------------------------
287 */
288 void dsi_sleep(DSI *dsi, const int state)
289 {
290     dsi->asleep = state;
291 }
292
293 /* ---------------------------------------
294 */
295 static void block_sig(DSI *dsi)
296 {
297   dsi->in_write++;
298 }
299
300 /* ---------------------------------------
301 */
302 static void unblock_sig(DSI *dsi)
303 {
304   dsi->in_write--;
305 }
306
307 /* ---------------------------------------
308  * write data. 0 on failure. this assumes that dsi_len will never
309  * cause an overflow in the data buffer. 
310  */
311 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
312 {
313   char block[DSI_BLOCKSIZ];
314 #ifdef USE_WRITEV
315   struct iovec iov[2];
316   size_t towrite;
317   ssize_t len;
318 #endif /* USE_WRITEV */
319
320   block[0] = dsi->header.dsi_flags;
321   block[1] = dsi->header.dsi_command;
322   memcpy(block + 2, &dsi->header.dsi_requestID, 
323          sizeof(dsi->header.dsi_requestID));
324   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
325   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
326   memcpy(block + 12, &dsi->header.dsi_reserved,
327          sizeof(dsi->header.dsi_reserved));
328
329   if (!length) { /* just write the header */
330     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
331     return length; /* really 0 on failure, 1 on success */
332   }
333   
334   /* block signals */
335   block_sig(dsi);
336 #ifdef USE_WRITEV
337   iov[0].iov_base = block;
338   iov[0].iov_len = sizeof(block);
339   iov[1].iov_base = buf;
340   iov[1].iov_len = length;
341   
342   towrite = sizeof(block) + length;
343   dsi->write_count += towrite;
344   while (towrite > 0) {
345     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
346         !len)
347       continue;
348     
349     if ((size_t)len == towrite) /* wrote everything out */
350       break;
351     else if (len < 0) { /* error */
352       if (errno == EAGAIN || errno == EWOULDBLOCK) {
353           if (!dsi_buffer(dsi)) {
354               continue;
355           }
356       }
357       LOG(log_error, logtype_default, "dsi_stream_send: %s", strerror(errno));
358       unblock_sig(dsi);
359       return 0;
360     }
361     
362     towrite -= len;
363     if (towrite > length) { /* skip part of header */
364       iov[0].iov_base = (char *) iov[0].iov_base + len;
365       iov[0].iov_len -= len;
366     } else { /* skip to data */
367       if (iov[0].iov_len) {
368         len -= iov[0].iov_len;
369         iov[0].iov_len = 0;
370       }
371       iov[1].iov_base = (char *) iov[1].iov_base + len;
372       iov[1].iov_len -= len;
373     }
374   }
375   
376 #else /* USE_WRITEV */
377   /* write the header then data */
378   if ((dsi_stream_write(dsi, block, sizeof(block), 1) != sizeof(block)) ||
379             (dsi_stream_write(dsi, buf, length, 0) != length)) {
380       unblock_sig(dsi);
381       return 0;
382   }
383 #endif /* USE_WRITEV */
384
385   unblock_sig(dsi);
386   return 1;
387 }
388
389
390 /* ---------------------------------------
391  * read data. function on success. 0 on failure. data length gets
392  * stored in length variable. this should really use size_t's, but
393  * that would require changes elsewhere. */
394 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
395                        size_t *rlength)
396 {
397   char block[DSI_BLOCKSIZ];
398
399   /* read in the header */
400   if (dsi_buffered_stream_read(dsi, (u_int8_t *)block, sizeof(block)) != sizeof(block)) 
401     return 0;
402
403   dsi->header.dsi_flags = block[0];
404   dsi->header.dsi_command = block[1];
405   /* FIXME, not the right place, 
406      but we get a server disconnect without reason in the log
407   */
408   if (!block[1]) {
409       LOG(log_error, logtype_default, "dsi_stream_receive: invalid packet, fatal");
410       return 0;
411   }
412
413   memcpy(&dsi->header.dsi_requestID, block + 2, 
414          sizeof(dsi->header.dsi_requestID));
415   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
416   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
417   memcpy(&dsi->header.dsi_reserved, block + 12,
418          sizeof(dsi->header.dsi_reserved));
419   dsi->clientID = ntohs(dsi->header.dsi_requestID);
420   
421   /* make sure we don't over-write our buffers. */
422   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
423   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
424     return 0;
425
426   return block[1];
427 }