]> arthur.barton.de Git - netatalk.git/blob - libatalk/dsi/dsi_stream.c
- merge branch-netatalk-afp-3x-dev, HEAD was tagged before
[netatalk.git] / libatalk / dsi / dsi_stream.c
1 /*
2  * $Id: dsi_stream.c,v 1.12 2005-04-28 20:50:02 bfernhomberg Exp $
3  *
4  * Copyright (c) 1998 Adrian Sun (asun@zoology.washington.edu)
5  * All rights reserved. See COPYRIGHT.
6  *
7  * this file provides the following functions:
8  * dsi_stream_write:    just write a bunch of bytes.
9  * dsi_stream_read:     just read a bunch of bytes.
10  * dsi_stream_send:     send a DSI header + data.
11  * dsi_stream_receive:  read a DSI header + data.
12  */
13
14 #ifdef HAVE_CONFIG_H
15 #include "config.h"
16 #endif /* HAVE_CONFIG_H */
17
18 #define USE_WRITEV
19
20 #include <stdio.h>
21 #include <stdlib.h>
22
23 #ifdef HAVE_UNISTD_H
24 #include <unistd.h>
25 #endif
26
27 #include <string.h>
28 #include <errno.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31
32 #ifdef USE_WRITEV
33 #include <sys/uio.h>
34 #endif
35
36 #include <atalk/logger.h>
37
38 #include <atalk/dsi.h>
39 #include <netatalk/endian.h>
40
41 #define min(a,b)  ((a) < (b) ? (a) : (b))
42
43 #ifndef MSG_MORE
44 #define MSG_MORE 0x8000
45 #endif
46
47 #ifndef MSG_DONTWAIT
48 #define MSG_DONTWAIT 0x40
49 #endif
50
51 /* ------------------------- 
52  * we don't use a circular buffer.
53 */
54 void dsi_buffer(DSI *dsi)
55 {
56     fd_set readfds, writefds;
57     int    len;
58     int    maxfd;
59
60     FD_ZERO(&readfds);
61     FD_ZERO(&writefds);
62     FD_SET( dsi->socket, &readfds);
63     FD_SET( dsi->socket, &writefds);
64     maxfd = dsi->socket +1;
65     while (1) {
66         FD_SET( dsi->socket, &readfds);
67         FD_SET( dsi->socket, &writefds);
68         if (select( maxfd, &readfds, &writefds, NULL, NULL) <= 0)
69             return;
70
71         if ( !FD_ISSET(dsi->socket, &readfds)) {
72             /* nothing waiting in the read queue */
73             return;
74         }
75         if (!dsi->buffer) {
76             /* XXX config options */
77             dsi->maxsize = 6 * dsi->server_quantum;
78             if (!dsi->maxsize)
79                 dsi->maxsize = 6 * DSI_SERVQUANT_DEF;
80             dsi->buffer = malloc(dsi->maxsize);
81             if (!dsi->buffer) {
82                 /* fall back to blocking IO */
83                 dsi_block(dsi, 0);
84                 return;
85             }
86             dsi->start = dsi->buffer;
87             dsi->eof = dsi->buffer;
88             dsi->end = dsi->buffer + dsi->maxsize;
89         }
90         len = dsi->end - dsi->eof;
91
92         if (len <= 0) {
93             /* ouch, our buffer is full ! 
94              * fall back to blocking IO 
95              * could block and disconnect but it's better than a cpu hog
96              */
97             dsi_block(dsi, 0);
98             return;
99         }
100
101         len = read(dsi->socket, dsi->eof, len);
102         if (len <= 0)
103             return;
104         dsi->eof += len;
105         if ( FD_ISSET(dsi->socket, &writefds)) {
106             return;
107         }
108     }
109 }
110
111 /* ------------------------------
112  * write raw data. return actual bytes read. checks against EINTR
113  * aren't necessary if all of the signals have SA_RESTART
114  * specified. */
115 size_t dsi_stream_write(DSI *dsi, void *data, const size_t length, int mode _U_)
116 {
117   size_t written;
118   ssize_t len;
119 #if 0
120   /* FIXME sometime it's slower */
121   unsigned int flags = (mode)?MSG_MORE:0;
122 #endif
123   unsigned int flags = 0;
124
125 #if 0
126   /* XXX there's no MSG_DONTWAIT in recv ?? so we have to play with ioctl
127   */ 
128   if (dsi->noblocking) {
129       flags |= MSG_DONTWAIT;
130   }
131 #endif
132   
133   written = 0;
134   while (written < length) {
135     if ((-1 == (len = send(dsi->socket, (u_int8_t *) data + written,
136                       length - written, flags)) && errno == EINTR) ||
137         !len)
138       continue;
139
140     if (len < 0) {
141       if (dsi->noblocking && errno ==  EAGAIN) {
142          /* non blocking mode but will block 
143           * read data in input queue.
144           * 
145          */
146          dsi_buffer(dsi);
147       }
148       else {
149           LOG(log_error, logtype_default, "dsi_stream_write: %s", strerror(errno));
150           break;
151       }
152     }
153     else {
154         written += len;
155     }
156   }
157
158   dsi->write_count += written;
159   return written;
160 }
161
162 /* ---------------------------------
163 */
164 static ssize_t buf_read(DSI *dsi, u_int8_t *buf, size_t count)
165 {
166     ssize_t nbe = 0;
167     ssize_t ret;
168     
169     if (!count)
170         return 0;
171         
172     if (dsi->start) {        
173         nbe = dsi->eof - dsi->start;
174
175         if (nbe > 0) {
176            nbe = min((size_t)nbe, count);
177            memcpy(buf, dsi->start, nbe);
178            dsi->start += nbe;
179
180            if (dsi->eof == dsi->start) 
181                dsi->start = dsi->eof = dsi->buffer;
182
183            if (nbe == count)
184                return nbe;
185            count -= nbe;
186            buf += nbe;
187         }
188         else 
189            nbe = 0;
190     }
191   
192     ret = read(dsi->socket, buf, count);
193     if (ret <= 0)
194         return ret;
195
196     return ret +nbe;
197 }
198
199 /* ---------------------------------------
200  * read raw data. return actual bytes read. this will wait until 
201  * it gets length bytes 
202  */
203 size_t dsi_stream_read(DSI *dsi, void *data, const size_t length)
204 {
205   size_t stored;
206   ssize_t len;
207   
208   stored = 0;
209   while (stored < length) {
210     len = buf_read(dsi, (u_int8_t *) data + stored, length - stored);
211     if (len == -1 && errno == EINTR)
212       continue;
213     else if (len > 0)
214       stored += len;
215     else { /* eof or error */
216       /* don't log EOF error if it's just after connect (OSX 10.3 probe) */
217       if (len || stored || dsi->read_count) {
218           LOG(log_error, logtype_default, "dsi_stream_read(%d): %s", len, (len < 0)?strerror(errno):"unexpected EOF");
219       }
220       break;
221     }
222   }
223
224   dsi->read_count += stored;
225   return stored;
226 }
227
228 /* ---------------------------------------
229 */
230 void dsi_sleep(DSI *dsi, const int state)
231 {
232     dsi->asleep = state;
233 }
234
235 /* ---------------------------------------
236 */
237 static void block_sig(DSI *dsi)
238 {
239   if (!dsi->sigblocked) sigprocmask(SIG_BLOCK, &dsi->sigblockset, &dsi->oldset);
240 }
241
242 /* ---------------------------------------
243 */
244 static void unblock_sig(DSI *dsi)
245 {
246   if (!dsi->sigblocked) sigprocmask(SIG_SETMASK, &dsi->oldset, NULL);
247 }
248
249 /* ---------------------------------------
250  * write data. 0 on failure. this assumes that dsi_len will never
251  * cause an overflow in the data buffer. 
252  */
253 int dsi_stream_send(DSI *dsi, void *buf, size_t length)
254 {
255   char block[DSI_BLOCKSIZ];
256 #ifdef USE_WRITEV
257   struct iovec iov[2];
258   size_t towrite;
259   ssize_t len;
260 #endif /* USE_WRITEV */
261
262   block[0] = dsi->header.dsi_flags;
263   block[1] = dsi->header.dsi_command;
264   memcpy(block + 2, &dsi->header.dsi_requestID, 
265          sizeof(dsi->header.dsi_requestID));
266   memcpy(block + 4, &dsi->header.dsi_code, sizeof(dsi->header.dsi_code));
267   memcpy(block + 8, &dsi->header.dsi_len, sizeof(dsi->header.dsi_len));
268   memcpy(block + 12, &dsi->header.dsi_reserved,
269          sizeof(dsi->header.dsi_reserved));
270
271   if (!length) { /* just write the header */
272     length = (dsi_stream_write(dsi, block, sizeof(block), 0) == sizeof(block));
273     return length; /* really 0 on failure, 1 on success */
274   }
275   
276   /* block signals */
277   block_sig(dsi);
278 #ifdef USE_WRITEV
279   iov[0].iov_base = block;
280   iov[0].iov_len = sizeof(block);
281   iov[1].iov_base = buf;
282   iov[1].iov_len = length;
283   
284   towrite = sizeof(block) + length;
285   dsi->write_count += towrite;
286   while (towrite > 0) {
287     if (((len = writev(dsi->socket, iov, 2)) == -1 && errno == EINTR) || 
288         !len)
289       continue;
290     
291     if (len == towrite) /* wrote everything out */
292       break;
293     else if (len < 0) { /* error */
294       LOG(log_error, logtype_default, "dsi_stream_send: %s", strerror(errno));
295       unblock_sig(dsi);
296       return 0;
297     }
298     
299     towrite -= len;
300     if (towrite > length) { /* skip part of header */
301       iov[0].iov_base = (char *) iov[0].iov_base + len;
302       iov[0].iov_len -= len;
303     } else { /* skip to data */
304       if (iov[0].iov_len) {
305         len -= iov[0].iov_len;
306         iov[0].iov_len = 0;
307       }
308       iov[1].iov_base = (char *) iov[1].iov_base + len;
309       iov[1].iov_len -= len;
310     }
311   }
312   
313 #else /* USE_WRITEV */
314   /* write the header then data */
315   if ((dsi_stream_write(dsi, block, sizeof(block), 1) != sizeof(block)) ||
316             (dsi_stream_write(dsi, buf, length, 0) != length)) {
317       unblock_sig(dsi);
318       return 0;
319   }
320 #endif /* USE_WRITEV */
321
322   unblock_sig(dsi);
323   return 1;
324 }
325
326
327 /* ---------------------------------------
328  * read data. function on success. 0 on failure. data length gets
329  * stored in length variable. this should really use size_t's, but
330  * that would require changes elsewhere. */
331 int dsi_stream_receive(DSI *dsi, void *buf, const size_t ilength,
332                        size_t *rlength)
333 {
334   char block[DSI_BLOCKSIZ];
335
336   /* read in the header */
337   if (dsi_stream_read(dsi, block, sizeof(block)) != sizeof(block)) 
338     return 0;
339
340   dsi->header.dsi_flags = block[0];
341   dsi->header.dsi_command = block[1];
342   /* FIXME, not the right place, 
343      but we get a server disconnect without reason in the log
344   */
345   if (!block[1]) {
346       LOG(log_error, logtype_default, "dsi_stream_receive: invalid packet, fatal");
347       return 0;
348   }
349
350   memcpy(&dsi->header.dsi_requestID, block + 2, 
351          sizeof(dsi->header.dsi_requestID));
352   memcpy(&dsi->header.dsi_code, block + 4, sizeof(dsi->header.dsi_code));
353   memcpy(&dsi->header.dsi_len, block + 8, sizeof(dsi->header.dsi_len));
354   memcpy(&dsi->header.dsi_reserved, block + 12,
355          sizeof(dsi->header.dsi_reserved));
356   dsi->clientID = ntohs(dsi->header.dsi_requestID);
357   
358   /* make sure we don't over-write our buffers. */
359   *rlength = min(ntohl(dsi->header.dsi_len), ilength);
360   if (dsi_stream_read(dsi, buf, *rlength) != *rlength) 
361     return 0;
362
363   return block[1];
364 }