]> arthur.barton.de Git - netatalk.git/blob - etc/cnid_dbd/comm.c
Remove SO_RCVTIMEO/SO_SNDTIMEO, switch to non-blocking IO
[netatalk.git] / etc / cnid_dbd / comm.c
1 /*
2  * Copyright (C) Joerg Lenneis 2003
3  * Copyright (C) Frank Lahm 2010
4  *
5  * All Rights Reserved.  See COPYING.
6  */
7
8 #ifdef HAVE_CONFIG_H
9 #include "config.h"
10 #endif
11
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <string.h>
15 #include <errno.h>
16 #include <unistd.h>
17 #include <sys/param.h>
18 #include <sys/types.h>
19 #include <sys/time.h>
20 #include <sys/uio.h>
21 #include <sys/socket.h>
22 #include <sys/select.h>
23 #include <assert.h>
24 #include <time.h>
25
26 #include <atalk/logger.h>
27 #include <atalk/util.h>
28 #include <atalk/cnid_dbd_private.h>
29
30 #include "db_param.h"
31 #include "usockfd.h"
32 #include "comm.h"
33
34 /* Length of the space taken up by a padded control message of length len */
35 #ifndef CMSG_SPACE
36 #define CMSG_SPACE(len) (__CMSG_ALIGN(sizeof(struct cmsghdr)) + __CMSG_ALIGN(len))
37 #endif
38
39
40 struct connection {
41     time_t tm;                    /* When respawned last */
42     int    fd;
43 };
44
45 static int   control_fd;
46 static int   cur_fd;
47 static struct connection *fd_table;
48 static int  fd_table_size;
49 static int  fds_in_use = 0;
50
51
52 static void invalidate_fd(int fd)
53 {
54     int i;
55
56     if (fd == control_fd)
57         return;
58     for (i = 0; i != fds_in_use; i++)
59         if (fd_table[i].fd == fd)
60             break;
61
62     assert(i < fds_in_use);
63
64     fds_in_use--;
65     fd_table[i] = fd_table[fds_in_use];
66     fd_table[fds_in_use].fd = -1;
67     close(fd);
68     return;
69 }
70
71 /*!
72  * non-blocking drop-in replacement for read with timeout using select
73  *
74  * @param socket   (r)  must be nonblocking !
75  * @param data     (rw) buffer for the read data
76  * @param lenght   (r)  how many bytes to read
77  * @param timeout  (r)  number of seconds to try reading
78  *
79  * @returns number of bytes actually read or -1 on fatal error
80  */
81 static ssize_t readt(int socket, void *data, const size_t length, int timeout)
82 {
83     size_t stored;
84     ssize_t len;
85     struct timeval tv;
86     fd_set rfds;
87     int ret;
88
89     stored = 0;
90
91     while (stored < length) {
92         len = read(socket, (u_int8_t *) data + stored, length - stored);
93         if (len == -1) {
94             switch (errno) {
95             case EINTR:
96                 continue;
97             case EAGAIN:
98                 tv.tv_usec = 0;
99                 tv.tv_sec  = timeout;
100
101                 FD_ZERO(&rfds);
102                 FD_SET(socket, &rfds);
103                 while ((ret = select(socket + 1, &rfds, NULL, NULL, &tv)) < 1) {
104                     switch (ret) {
105                     case 0:
106                         LOG(log_warning, logtype_cnid, "select timeout 1s");
107                         return stored;
108                     default: /* -1 */
109                         LOG(log_error, logtype_cnid, "select: %s", strerror(errno));
110                         return -1;
111                     }
112                 }
113                 continue;
114             }
115             LOG(log_error, logtype_cnid, "read: %s", strerror(errno));
116             return -1;
117         }
118         else if (len > 0)
119             stored += len;
120         else
121             break;
122     }
123     return stored;
124 }
125
126
127 static int recv_cred(int fd)
128 {
129     int ret;
130     struct msghdr msgh;
131     struct iovec iov[1];
132     struct cmsghdr *cmsgp = NULL;
133     char buf[CMSG_SPACE(sizeof(int))];
134     char dbuf[80];
135
136     memset(&msgh,0,sizeof(msgh));
137     memset(buf,0,sizeof(buf));
138
139     msgh.msg_name = NULL;
140     msgh.msg_namelen = 0;
141
142     msgh.msg_iov = iov;
143     msgh.msg_iovlen = 1;
144
145     iov[0].iov_base = dbuf;
146     iov[0].iov_len = sizeof(dbuf);
147
148     msgh.msg_control = buf;
149     msgh.msg_controllen = sizeof(buf);
150
151     do  {
152         ret = recvmsg(fd ,&msgh,0);
153     } while ( ret == -1 && errno == EINTR );
154
155     if ( ret == -1 ) {
156         return -1;
157     }
158
159     for ( cmsgp = CMSG_FIRSTHDR(&msgh); cmsgp != NULL; cmsgp = CMSG_NXTHDR(&msgh,cmsgp) ) {
160         if ( cmsgp->cmsg_level == SOL_SOCKET && cmsgp->cmsg_type == SCM_RIGHTS ) {
161             return *(int *) CMSG_DATA(cmsgp);
162         }
163     }
164
165     if ( ret == sizeof (int) )
166         errno = *(int *)dbuf; /* Rcvd errno */
167     else
168         errno = ENOENT;    /* Default errno */
169
170     return -1;
171 }
172
173 /*
174  *  Check for client requests. We keep up to fd_table_size open descriptors in
175  *  fd_table. If the table is full and we get a new descriptor via
176  *  control_fd, we close a random decriptor in the table to make space. The
177  *  affected client will automatically reconnect. For an EOF (descriptor is
178  *  closed by the client, so a read here returns 0) comm_rcv will take care of
179  *  things and clean up fd_table. The same happens for any read/write errors.
180  */
181
182 static int check_fd(time_t timeout, const sigset_t *sigmask, time_t *now)
183 {
184     int fd;
185     fd_set readfds;
186     struct timespec tv;
187     int ret;
188     int i;
189     int maxfd = control_fd;
190     time_t t;
191
192     FD_ZERO(&readfds);
193     FD_SET(control_fd, &readfds);
194
195     for (i = 0; i != fds_in_use; i++) {
196         FD_SET(fd_table[i].fd, &readfds);
197         if (maxfd < fd_table[i].fd)
198             maxfd = fd_table[i].fd;
199     }
200
201     tv.tv_nsec = 0;
202     tv.tv_sec  = timeout;
203     if ((ret = pselect(maxfd + 1, &readfds, NULL, NULL, &tv, sigmask)) < 0) {
204         if (errno == EINTR)
205             return 0;
206         LOG(log_error, logtype_cnid, "error in select: %s",strerror(errno));
207         return -1;
208     }
209
210     time(&t);
211     if (now)
212         *now = t;
213
214     if (!ret)
215         return 0;
216
217
218     if (FD_ISSET(control_fd, &readfds)) {
219         int    l = 0;
220
221         fd = recv_cred(control_fd);
222         if (fd < 0) {
223             return -1;
224         }
225         if (fds_in_use < fd_table_size) {
226             fd_table[fds_in_use].fd = fd;
227             fd_table[fds_in_use].tm = t;
228             fds_in_use++;
229         } else {
230             time_t older = t;
231
232             for (i = 0; i != fds_in_use; i++) {
233                 if (older <= fd_table[i].tm) {
234                     older = fd_table[i].tm;
235                     l = i;
236                 }
237             }
238             close(fd_table[l].fd);
239             fd_table[l].fd = fd;
240             fd_table[l].tm = t;
241         }
242         return 0;
243     }
244
245     for (i = 0; i != fds_in_use; i++) {
246         if (FD_ISSET(fd_table[i].fd, &readfds)) {
247             fd_table[i].tm = t;
248             return fd_table[i].fd;
249         }
250     }
251     /* We should never get here */
252     return 0;
253 }
254
255 int comm_init(struct db_param *dbp, int ctrlfd, int clntfd)
256 {
257     int i;
258
259     fds_in_use = 0;
260     fd_table_size = dbp->fd_table_size;
261
262     if ((fd_table = malloc(fd_table_size * sizeof(struct connection))) == NULL) {
263         LOG(log_error, logtype_cnid, "Out of memory");
264         return -1;
265     }
266     for (i = 0; i != fd_table_size; i++)
267         fd_table[i].fd = -1;
268     /* from dup2 */
269     control_fd = ctrlfd;
270 #if 0
271     int b = 1;
272     /* this one dump core in recvmsg, great */
273     if ( setsockopt(control_fd, SOL_SOCKET, SO_PASSCRED, &b, sizeof (b)) < 0) {
274         LOG(log_error, logtype_cnid, "setsockopt SO_PASSCRED %s",  strerror(errno));
275         return -1;
276     }
277 #endif
278     /* push the first client fd */
279     fd_table[fds_in_use].fd = clntfd;
280     fds_in_use++;
281
282     return 0;
283 }
284
285 /* ------------
286    nbe of clients
287 */
288 int comm_nbe(void)
289 {
290     return fds_in_use;
291 }
292
293 /* ------------ */
294 int comm_rcv(struct cnid_dbd_rqst *rqst, time_t timeout, const sigset_t *sigmask, time_t *now)
295 {
296     char *nametmp;
297     int b;
298
299     if ((cur_fd = check_fd(timeout, sigmask, now)) < 0)
300         return -1;
301
302     if (!cur_fd)
303         return 0;
304
305     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got data on fd %u", cur_fd);
306
307     if (setnonblock(cur_fd, 1) != 0) {
308         LOG(log_error, logtype_cnid, "comm_rcv: setnonblock: %s", strerror(errno));
309         return -1;
310     }
311
312     nametmp = rqst->name;
313     if ((b = readt(cur_fd, rqst, sizeof(struct cnid_dbd_rqst), CNID_DBD_TIMEOUT))
314         != sizeof(struct cnid_dbd_rqst)) {
315         if (b)
316             LOG(log_error, logtype_cnid, "error reading message header: %s", strerror(errno));
317         invalidate_fd(cur_fd);
318         rqst->name = nametmp;
319         return 0;
320     }
321     rqst->name = nametmp;
322     if (rqst->namelen && readt(cur_fd, rqst->name, rqst->namelen, CNID_DBD_TIMEOUT)
323         != rqst->namelen) {
324         LOG(log_error, logtype_cnid, "error reading message name: %s", strerror(errno));
325         invalidate_fd(cur_fd);
326         return 0;
327     }
328     /* We set this to make life easier for logging. None of the other stuff
329        needs zero terminated strings. */
330     rqst->name[rqst->namelen] = '\0';
331
332     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got %u bytes", b + rqst->namelen);
333
334     return 1;
335 }
336
337 /* ------------ */
338 #define USE_WRITEV
339 int comm_snd(struct cnid_dbd_rply *rply)
340 {
341 #ifdef USE_WRITEV
342     struct iovec iov[2];
343     size_t towrite;
344 #endif
345
346     if (!rply->namelen) {
347         if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
348             LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
349             invalidate_fd(cur_fd);
350             return 0;
351         }
352         return 1;
353     }
354 #ifdef USE_WRITEV
355
356     iov[0].iov_base = rply;
357     iov[0].iov_len = sizeof(struct cnid_dbd_rply);
358     iov[1].iov_base = rply->name;
359     iov[1].iov_len = rply->namelen;
360     towrite = sizeof(struct cnid_dbd_rply) +rply->namelen;
361
362     if (writev(cur_fd, iov, 2) != towrite) {
363         LOG(log_error, logtype_cnid, "error writing message : %s", strerror(errno));
364         invalidate_fd(cur_fd);
365         return 0;
366     }
367 #else
368     if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
369         LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
370         invalidate_fd(cur_fd);
371         return 0;
372     }
373     if (write(cur_fd, rply->name, rply->namelen) != rply->namelen) {
374         LOG(log_error, logtype_cnid, "error writing message name: %s", strerror(errno));
375         invalidate_fd(cur_fd);
376         return 0;
377     }
378 #endif
379     return 1;
380 }
381
382