]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
028364fadbb1685a8bc9ef45eaa32785bbb66796
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * $Id: dsi_stream.c,v 1.14 2009-10-19 11:01:51 didg Exp $
3  *
4  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
5  * All rights reserved. See COPYRIGHT.
6  *
7  * this file provides the following functions:
8  * dsi_stream_write:    just write a bunch of bytes.
9  * dsi_stream_read:     just read a bunch of bytes.
10  * dsi_stream_send:     send a DSI header + data.
11  * dsi_stream_receive:  read a DSI header + data.
12  */
13
14 #ifdef HAVE_CONFIG_H
15 #include "config.h"
16 #endif /* HAVE_CONFIG_H */
17
18 #define USE_WRITEV
19
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #ifdef HAVE_UNISTD_H
24 #include <unistd.h>
25 #endif
26
27 #include <string.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31
32 #ifdef USE_WRITEV
33 #include <sys/uio.h>
34 #endif
35
36 #include <atalk/logger.h>
37
38 #include <atalk/dsi.h>
39 #include <netatalk/endian.h>
40
41 #define min(a,b)  ((a) < (b) ? (a) : (b))
42
43 #ifndef MSG_MORE
44 #define MSG_MORE 0x8000
45 #endif
46
47 #ifndef MSG_DONTWAIT
48 #define MSG_DONTWAIT 0x40
49 #endif
50
51 /* ------------------------- 
52  * we don't use a circular buffer.
53 */
54 static void dsi_init_buffer(DSI *dsi)
55 {
56     if (!dsi->buffer) {
57         /* XXX config options */
58         dsi->maxsize = 6 * dsi->server_quantum;
59         if (!dsi->maxsize)
60             dsi->maxsize = 6 * DSI_SERVQUANT_DEF;
61         dsi->buffer = malloc(dsi->maxsize);
62         if (!dsi->buffer) {
63             return;
64         }
65         dsi->start = dsi->buffer;
66         dsi->eof = dsi->buffer;
67         dsi->end = dsi->buffer + dsi->maxsize;
68     }
69 }
70
71 /* ---------------------- */
72 static void dsi_buffer(DSI *dsi)
73 {
74     fd_set readfds, writefds;
75     int    len;
76     int    maxfd;
77
78     FD_ZERO(&readfds);
79     FD_ZERO(&writefds);
80     FD_SET( dsi->socket, &readfds);
81     FD_SET( dsi->socket, &writefds);
82     maxfd = dsi->socket +1;
83     while (1) {
84         FD_SET( dsi->socket, &readfds);
85         FD_SET( dsi->socket, &writefds);
86         if (select( maxfd, &readfds, &writefds, NULL, NULL) <= 0)
87             return;
88
89         if ( !FD_ISSET(dsi->socket, &readfds)) {
90             /* nothing waiting in the read queue */
91             return;
92         }
93         dsi_init_buffer(dsi);
94         len = dsi->end - dsi->eof;
95
96         if (len <= 0) {
97             /* ouch, our buffer is full ! 
98              * fall back to blocking IO 
99              * could block and disconnect but it's better than a cpu hog
100              */
101             dsi_block(dsi, 0);
102             return;
103         }
104
105         len = read(dsi->socket, dsi->eof, len);
106         if (len <= 0)
107             return;
108         dsi->eof += len;
109         if ( FD_ISSET(dsi->socket, &writefds)) {
110             return;
111         }
112     }
113 }
114
115 /* ------------------------------
116  * write raw data. return actual bytes read. checks against EINTR
117  * aren't necessary if all of the signals have SA_RESTART
118  * specified. */
119 size_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode _U_)
120 {
121   size_t written;
122   ssize_t len;
123 #if 0
124   /* FIXME sometime it's slower */
125   unsigned int flags = (mode)?MSG_MORE:0;
126 #endif
127   unsigned int flags = 0;
128
129 #if 0
130   /* XXX there's no MSG_DONTWAIT in recv ?? so we have to play with ioctl
131   */ 
132   if (dsi->noblocking) {
133       flags |= MSG_DONTWAIT;
134   }
135 #endif
136   
137   written = 0;
138   while (written < length) {
139     if ((-1 == (len = send(dsi->socket, (u_int8_t *) data + written,
140                       length - written, flags)) && errno == EINTR) ||
141         !len)
142       continue;
143
144     if (len < 0) {
145       if (dsi->noblocking && errno ==  EAGAIN) {
146          /* non blocking mode but will block 
147           * read data in input queue.
148           * 
149          */
150          dsi_buffer(dsi);
151       }
152       else {
153           LOG(log_error, logtype_default, "dsi_stream_write: %s", strerror(errno));
154           break;
155       }
156     }
157     else {
158         written += len;
159     }
160   }
161
162   dsi->write_count += written;
163   return written;
164 }
165
166 /* ---------------------------------
167 */
168 static size_t from_buf(DSI *dsi, u_int8_t *buf, size_t count)
169 {
170     size_t nbe = 0;
171     
172     if (dsi->start) {        
173         nbe = dsi->eof - dsi->start;
174
175         if (nbe > 0) {
176            nbe = min((size_t)nbe, count);
177            memcpy(buf, dsi->start, nbe);
178            dsi->start += nbe;
179
180            if (dsi->eof == dsi->start) 
181                dsi->start = dsi->eof = dsi->buffer;
182
183         }
184     }
185     return nbe;
186 }
187
188 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
189 {
190     ssize_t nbe;
191     
192     if (!count)
193         return 0;
194
195     nbe = from_buf(dsi, buf, count);
196     if (nbe)
197         return nbe;
198   
199     return read(dsi->socket, buf, count);
200
201 }
202
203 /* ---------------------------------------
204  * read raw data. return actual bytes read. this will wait until 
205  * it gets length bytes 
206  */
207 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
208 {
209   size_t stored;
210   ssize_t len;
211   
212   stored = 0;
213   while (stored < length) {
214     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
215     if (len == -1 && errno == EINTR)
216       continue;
217     else if (len > 0)
218       stored += len;
219     else { /* eof or error */
220       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
221       if (len || stored || dsi->read_count) {
222           LOG(log_error, logtype_default, "dsi_stream_read(%d): %s", len, (len < 0)?strerror(errno):"unexpected EOF");
223       }
224       break;
225     }
226   }
227
228   dsi->read_count += stored;
229   return stored;
230 }
231
232 /* ---------------------------------------
233  * read raw data. return actual bytes read. this will wait until 
234  * it gets length bytes 
235  */
236 static size_t dsi_buffered_stream_read(DSI *dsi, u_int8_t *data, const size_t length)
237 {
238   size_t len;
239   size_t buflen;
240   
241   dsi_init_buffer(dsi);
242   len = from_buf(dsi, data, length);
243   dsi->read_count += len;
244   if (len == length) {
245       return len;
246   }
247   
248   buflen = min(8192, dsi->end - dsi->eof);
249   if (buflen > 0) {
250       ssize_t ret;
251       ret = read(dsi->socket, dsi->eof, buflen);
252       if (ret > 0)
253           dsi->eof += ret;
254   }
255   return dsi_stream_read(dsi, data, length -len);
256 }
257
258 /* ---------------------------------------
259 */
260 void dsi_sleep(DSI *dsi, const int state)
261 {
262     dsi->asleep = state;
263 }
264
265 /* ---------------------------------------
266 */
267 static void block_sig(DSI *dsi)
268 {
269   if (!dsi->sigblocked) sigprocmask(SIG_BLOCK, &dsi->sigblockset, &dsi->oldset);
270 }
271
272 /* ---------------------------------------
273 */
274 static void unblock_sig(DSI *dsi)
275 {
276   if (!dsi->sigblocked) sigprocmask(SIG_SETMASK, &dsi->oldset, NULL);
277 }
278
279 /* ---------------------------------------
280  * write data. 0 on failure. this assumes that dsi_len will never
281  * cause an overflow in the data buffer. 
282  */
283 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
284 {
285   char block[DSI_BLOCKSIZ];
286 #ifdef USE_WRITEV
287   struct iovec iov[2];
288   size_t towrite;
289   ssize_t len;
290 #endif /* USE_WRITEV */
291
292   block[0] = dsi->header.dsi_flags;
293   block[1] = dsi->header.dsi_command;
294   memcpy(block + 2, &dsi->header.dsi_requestID, 
295          sizeof(dsi->header.dsi_requestID));
296   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
297   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
298   memcpy(block + 12, &dsi->header.dsi_reserved,
299          sizeof(dsi->header.dsi_reserved));
300
301   if (!length) { /* just write the header */
302     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
303     return length; /* really 0 on failure, 1 on success */
304   }
305   
306   /* block signals */
307   block_sig(dsi);
308 #ifdef USE_WRITEV
309   iov[0].iov_base = block;
310   iov[0].iov_len = sizeof(block);
311   iov[1].iov_base = buf;
312   iov[1].iov_len = length;
313   
314   towrite = sizeof(block) + length;
315   dsi->write_count += towrite;
316   while (towrite > 0) {
317     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
318         !len)
319       continue;
320     
321     if ((size_t)len == towrite) /* wrote everything out */
322       break;
323     else if (len < 0) { /* error */
324       LOG(log_error, logtype_default, "dsi_stream_send: %s", strerror(errno));
325       unblock_sig(dsi);
326       return 0;
327     }
328     
329     towrite -= len;
330     if (towrite > length) { /* skip part of header */
331       iov[0].iov_base = (char *) iov[0].iov_base + len;
332       iov[0].iov_len -= len;
333     } else { /* skip to data */
334       if (iov[0].iov_len) {
335         len -= iov[0].iov_len;
336         iov[0].iov_len = 0;
337       }
338       iov[1].iov_base = (char *) iov[1].iov_base + len;
339       iov[1].iov_len -= len;
340     }
341   }
342   
343 #else /* USE_WRITEV */
344   /* write the header then data */
345   if ((dsi_stream_write(dsi, block, sizeof(block), 1) != sizeof(block)) ||
346             (dsi_stream_write(dsi, buf, length, 0) != length)) {
347       unblock_sig(dsi);
348       return 0;
349   }
350 #endif /* USE_WRITEV */
351
352   unblock_sig(dsi);
353   return 1;
354 }
355
356
357 /* ---------------------------------------
358  * read data. function on success. 0 on failure. data length gets
359  * stored in length variable. this should really use size_t's, but
360  * that would require changes elsewhere. */
361 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
362                        size_t *rlength)
363 {
364   char block[DSI_BLOCKSIZ];
365
366   /* read in the header */
367   if (dsi_buffered_stream_read(dsi, block, sizeof(block)) != sizeof(block)) 
368     return 0;
369
370   dsi->header.dsi_flags = block[0];
371   dsi->header.dsi_command = block[1];
372   /* FIXME, not the right place, 
373      but we get a server disconnect without reason in the log
374   */
375   if (!block[1]) {
376       LOG(log_error, logtype_default, "dsi_stream_receive: invalid packet, fatal");
377       return 0;
378   }
379
380   memcpy(&dsi->header.dsi_requestID, block + 2, 
381          sizeof(dsi->header.dsi_requestID));
382   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
383   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
384   memcpy(&dsi->header.dsi_reserved, block + 12,
385          sizeof(dsi->header.dsi_reserved));
386   dsi->clientID = ntohs(dsi->header.dsi_requestID);
387   
388   /* make sure we don't over-write our buffers. */
389   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
390   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
391     return 0;
392
393   return block[1];
394 }