]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
ac5bbd56e769bac87a5068beda9ae9301961f6d1
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * $Id: dsi_stream.c,v 1.11.6.4 2004-02-10 10:21:50 didg Exp $
3  *
4  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
5  * All rights reserved. See COPYRIGHT.
6  *
7  * this file provides the following functions:
8  * dsi_stream_write:    just write a bunch of bytes.
9  * dsi_stream_read:     just read a bunch of bytes.
10  * dsi_stream_send:     send a DSI header + data.
11  * dsi_stream_receive:  read a DSI header + data.
12  */
13
14 #ifdef HAVE_CONFIG_H
15 #include "config.h"
16 #endif /* HAVE_CONFIG_H */
17
18 #define USE_WRITEV
19
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #ifdef HAVE_UNISTD_H
24 #include <unistd.h>
25 #endif
26
27 #include <string.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31
32 #ifdef USE_WRITEV
33 #include <sys/uio.h>
34 #endif
35
36 #include <atalk/logger.h>
37
38 #include <atalk/dsi.h>
39 #include <netatalk/endian.h>
40
41 #define min(a,b)  ((a) < (b) ? (a) : (b))
42
43 #ifndef MSG_MORE
44 #define MSG_MORE 0x8000
45 #endif
46
47 #ifndef MSG_DONTWAIT
48 #define MSG_DONTWAIT 0x40
49 #endif
50
51 /* ------------------------- 
52  * we don't use a circular buffer.
53 */
54 const void dsi_buffer(DSI *dsi)
55 {
56     fd_set readfds, writefds;
57     int    len;
58     int    maxfd;
59
60     FD_ZERO(&readfds);
61     FD_ZERO(&writefds);
62     FD_SET( dsi->socket, &readfds);
63     FD_SET( dsi->socket, &writefds);
64     maxfd = dsi->socket +1;
65     if (select( maxfd, &readfds, &writefds, NULL, NULL) <= 0)
66         return;
67
68     if ( !FD_ISSET(dsi->socket, &readfds)) {
69         /* nothing waiting in the queue */
70         return;
71     }
72     if (!dsi->buffer) {
73         /* XXX config options */
74         dsi->maxsize = 6 * dsi->server_quantum;
75         if (!dsi->maxsize)
76             dsi->maxsize = 6 * DSI_SERVQUANT_DEF;
77         dsi->buffer = malloc(dsi->maxsize);
78         if (!dsi->buffer) {
79            /* fall back to blocking IO */
80             dsi_block(dsi, 0);
81             return;
82         }
83         dsi->start = dsi->buffer;
84         dsi->eof = dsi->buffer;
85         dsi->end = dsi->buffer + dsi->maxsize;
86     }
87     len = dsi->end - dsi->eof;
88
89     if (len <= 0) {
90         /* ouch, our buffer is full ! 
91          * fall back to blocking IO 
92          * could block and disconnect but it's better than a cpu hog
93         */
94         dsi_block(dsi, 0);
95         return;
96     }
97
98     len = read(dsi->socket, dsi->eof, len);
99     if (len <= 0)
100         return;
101     dsi->eof += len;
102 }
103
104 /* ------------------------------
105  * write raw data. return actual bytes read. checks against EINTR
106  * aren't necessary if all of the signals have SA_RESTART
107  * specified. */
108 size_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode)
109 {
110   size_t written;
111   ssize_t len;
112 #if 0
113   /* FIXME sometime it's slower */
114   unsigned int flags = (mode)?MSG_MORE:0;
115 #endif
116   unsigned int flags = 0;
117
118 #if 0
119   /* XXX there's no MSG_DONTWAIT in recv ?? so we have to play with ioctl
120   */ 
121   if (dsi->noblocking) {
122       flags |= MSG_DONTWAIT;
123   }
124 #endif
125   
126   written = 0;
127   while (written < length) {
128     if ((-1 == (len = send(dsi->socket, (u_int8_t *) data + written,
129                       length - written, flags)) && errno == EINTR) ||
130         !len)
131       continue;
132
133     if (len < 0) {
134       if (dsi->noblocking && errno ==  EAGAIN) {
135          /* non blocking mode but will block 
136           * read data in input queue.
137           * 
138          */
139          dsi_buffer(dsi);
140       }
141       else {
142           LOG(log_error, logtype_default, "dsi_stream_write: %s", strerror(errno));
143           break;
144       }
145     }
146     else {
147         written += len;
148     }
149   }
150
151   dsi->write_count += written;
152   return written;
153 }
154
155 /* ---------------------------------
156 */
157 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
158 {
159     ssize_t nbe = 0;
160     ssize_t ret;
161     
162     if (!count)
163         return 0;
164         
165     if (dsi->start) {        
166         nbe = dsi->eof - dsi->start;
167
168         if (nbe > 0) {
169            nbe = min((size_t)nbe, count);
170            memcpy(buf, dsi->start, nbe);
171            dsi->start += nbe;
172
173            if (dsi->eof == dsi->start) 
174                dsi->start = dsi->eof = dsi->buffer;
175
176            if (nbe == count)
177                return nbe;
178            count -= nbe;
179            buf += nbe;
180         }
181         else 
182            nbe = 0;
183     }
184   
185     ret = read(dsi->socket, buf, count);
186     if (ret <= 0)
187         return ret;
188
189     return ret +nbe;
190 }
191
192 /* ---------------------------------------
193  * read raw data. return actual bytes read. this will wait until 
194  * it gets length bytes 
195  */
196 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
197 {
198   size_t stored;
199   ssize_t len;
200   
201   stored = 0;
202   while (stored < length) {
203     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
204     if (len == -1 && errno == EINTR)
205       continue;
206     else if (len > 0)
207       stored += len;
208     else { /* eof or error */
209       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
210       if (len || stored || dsi->read_count) {
211           LOG(log_error, logtype_default, "dsi_stream_read(%d): %s", len, (len < 0)?strerror(errno):"unexpected EOF");
212       }
213       break;
214     }
215   }
216
217   dsi->read_count += stored;
218   return stored;
219 }
220
221 /* ---------------------------------------
222 */
223 void dsi_sleep(DSI *dsi, const int state)
224 {
225     dsi->asleep = state;
226 }
227
228 /* ---------------------------------------
229 */
230 static void block_sig(DSI *dsi)
231 {
232   if (!dsi->sigblocked) sigprocmask(SIG_BLOCK, &dsi->sigblockset, &dsi->oldset);
233 }
234
235 /* ---------------------------------------
236 */
237 static void unblock_sig(DSI *dsi)
238 {
239   if (!dsi->sigblocked) sigprocmask(SIG_SETMASK, &dsi->oldset, NULL);
240 }
241
242 /* ---------------------------------------
243  * write data. 0 on failure. this assumes that dsi_len will never
244  * cause an overflow in the data buffer. 
245  */
246 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
247 {
248   char block[DSI_BLOCKSIZ];
249 #ifdef USE_WRITEV
250   struct iovec iov[2];
251   size_t towrite;
252   ssize_t len;
253 #endif /* USE_WRITEV */
254
255   block[0] = dsi->header.dsi_flags;
256   block[1] = dsi->header.dsi_command;
257   memcpy(block + 2, &dsi->header.dsi_requestID, 
258          sizeof(dsi->header.dsi_requestID));
259   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
260   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
261   memcpy(block + 12, &dsi->header.dsi_reserved,
262          sizeof(dsi->header.dsi_reserved));
263
264   /* block signals */
265   block_sig(dsi);
266
267   if (!length) { /* just write the header */
268     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
269     unblock_sig(dsi);
270     return length; /* really 0 on failure, 1 on success */
271   }
272   
273 #ifdef USE_WRITEV
274   iov[0].iov_base = block;
275   iov[0].iov_len = sizeof(block);
276   iov[1].iov_base = buf;
277   iov[1].iov_len = length;
278   
279   towrite = sizeof(block) + length;
280   dsi->write_count += towrite;
281   while (towrite > 0) {
282     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
283         !len)
284       continue;
285     
286     if (len == towrite) /* wrote everything out */
287       break;
288     else if (len < 0) { /* error */
289       LOG(log_error, logtype_default, "dsi_stream_send: %s", strerror(errno));
290       unblock_sig(dsi);
291       return 0;
292     }
293     
294     towrite -= len;
295     if (towrite > length) { /* skip part of header */
296       iov[0].iov_base = (char *) iov[0].iov_base + len;
297       iov[0].iov_len -= len;
298     } else { /* skip to data */
299       if (iov[0].iov_len) {
300         len -= iov[0].iov_len;
301         iov[0].iov_len = 0;
302       }
303       iov[1].iov_base = (char *) iov[1].iov_base + len;
304       iov[1].iov_len -= len;
305     }
306   }
307   
308 #else /* USE_WRITEV */
309   /* write the header then data */
310   if ((dsi_stream_write(dsi, block, sizeof(block), 1) != sizeof(block)) ||
311             (dsi_stream_write(dsi, buf, length, 0) != length)) {
312       unblock_sig(dsi);
313       return 0;
314   }
315 #endif /* USE_WRITEV */
316
317   unblock_sig(dsi);
318   return 1;
319 }
320
321
322 /* ---------------------------------------
323  * read data. function on success. 0 on failure. data length gets
324  * stored in length variable. this should really use size_t's, but
325  * that would require changes elsewhere. */
326 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
327                        size_t *rlength)
328 {
329   char block[DSI_BLOCKSIZ];
330
331   /* read in the header */
332   if (dsi_stream_read(dsi, block, sizeof(block)) != sizeof(block)) 
333     return 0;
334
335   dsi->header.dsi_flags = block[0];
336   dsi->header.dsi_command = block[1];
337   /* FIXME, not the right place, 
338      but we get a server disconnect without reason in the log
339   */
340   if (!block[1]) {
341       LOG(log_error, logtype_default, "dsi_stream_receive: invalid packet, fatal");
342       return 0;
343   }
344
345   memcpy(&dsi->header.dsi_requestID, block + 2, 
346          sizeof(dsi->header.dsi_requestID));
347   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
348   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
349   memcpy(&dsi->header.dsi_reserved, block + 12,
350          sizeof(dsi->header.dsi_reserved));
351   dsi->clientID = ntohs(dsi->header.dsi_requestID);
352   
353   /* make sure we don't over-write our buffers. */
354   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
355   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
356     return 0;
357
358   return block[1];
359 }