]> arthur.barton.de Git - netatalk.git/blob - etc/cnid_dbd/comm.c
Reapply ugly hack for sys/socket.h include
[netatalk.git] / etc / cnid_dbd / comm.c
1 /*
2  * Copyright (C) Joerg Lenneis 2003
3  * Copyright (C) Frank Lahm 2010
4  *
5  * All Rights Reserved.  See COPYING.
6  */
7
8 #ifdef HAVE_CONFIG_H
9 #include "config.h"
10 #endif
11
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <string.h>
15 #include <errno.h>
16 #include <unistd.h>
17 #include <sys/param.h>
18 #include <sys/types.h>
19 #include <sys/time.h>
20 #include <sys/uio.h>
21 #define _XPG4_2 1
22 #include <sys/socket.h>
23 #include <sys/select.h>
24 #include <assert.h>
25 #include <time.h>
26
27 #include <atalk/logger.h>
28 #include <atalk/util.h>
29 #include <atalk/cnid_dbd_private.h>
30
31 #include "db_param.h"
32 #include "usockfd.h"
33 #include "comm.h"
34
35 /* Length of the space taken up by a padded control message of length len */
36 #ifndef CMSG_SPACE
37 #define CMSG_SPACE(len) (__CMSG_ALIGN(sizeof(struct cmsghdr)) + __CMSG_ALIGN(len))
38 #endif
39
40
41 struct connection {
42     time_t tm;                    /* When respawned last */
43     int    fd;
44 };
45
46 static int   control_fd;
47 static int   cur_fd;
48 static struct connection *fd_table;
49 static int  fd_table_size;
50 static int  fds_in_use = 0;
51
52
53 static void invalidate_fd(int fd)
54 {
55     int i;
56
57     if (fd == control_fd)
58         return;
59     for (i = 0; i != fds_in_use; i++)
60         if (fd_table[i].fd == fd)
61             break;
62
63     assert(i < fds_in_use);
64
65     fds_in_use--;
66     fd_table[i] = fd_table[fds_in_use];
67     fd_table[fds_in_use].fd = -1;
68     close(fd);
69     return;
70 }
71
72 /*!
73  * non-blocking drop-in replacement for read with timeout using select
74  *
75  * @param socket   (r)  must be nonblocking !
76  * @param data     (rw) buffer for the read data
77  * @param lenght   (r)  how many bytes to read
78  * @param timeout  (r)  number of seconds to try reading
79  *
80  * @returns number of bytes actually read or -1 on fatal error
81  */
82 static ssize_t readt(int socket, void *data, const size_t length, int timeout)
83 {
84     size_t stored;
85     ssize_t len;
86     struct timeval tv;
87     fd_set rfds;
88     int ret;
89
90     stored = 0;
91
92     while (stored < length) {
93         len = read(socket, (u_int8_t *) data + stored, length - stored);
94         if (len == -1) {
95             switch (errno) {
96             case EINTR:
97                 continue;
98             case EAGAIN:
99                 tv.tv_usec = 0;
100                 tv.tv_sec  = timeout;
101
102                 FD_ZERO(&rfds);
103                 FD_SET(socket, &rfds);
104                 while ((ret = select(socket + 1, &rfds, NULL, NULL, &tv)) < 1) {
105                     switch (ret) {
106                     case 0:
107                         LOG(log_warning, logtype_cnid, "select timeout 1s");
108                         return stored;
109                     default: /* -1 */
110                         LOG(log_error, logtype_cnid, "select: %s", strerror(errno));
111                         return -1;
112                     }
113                 }
114                 continue;
115             }
116             LOG(log_error, logtype_cnid, "read: %s", strerror(errno));
117             return -1;
118         }
119         else if (len > 0)
120             stored += len;
121         else
122             break;
123     }
124     return stored;
125 }
126
127
128 static int recv_cred(int fd)
129 {
130     int ret;
131     struct msghdr msgh;
132     struct iovec iov[1];
133     struct cmsghdr *cmsgp = NULL;
134     char buf[CMSG_SPACE(sizeof(int))];
135     char dbuf[80];
136
137     memset(&msgh,0,sizeof(msgh));
138     memset(buf,0,sizeof(buf));
139
140     msgh.msg_name = NULL;
141     msgh.msg_namelen = 0;
142
143     msgh.msg_iov = iov;
144     msgh.msg_iovlen = 1;
145
146     iov[0].iov_base = dbuf;
147     iov[0].iov_len = sizeof(dbuf);
148
149     msgh.msg_control = buf;
150     msgh.msg_controllen = sizeof(buf);
151
152     do  {
153         ret = recvmsg(fd ,&msgh,0);
154     } while ( ret == -1 && errno == EINTR );
155
156     if ( ret == -1 ) {
157         return -1;
158     }
159
160     for ( cmsgp = CMSG_FIRSTHDR(&msgh); cmsgp != NULL; cmsgp = CMSG_NXTHDR(&msgh,cmsgp) ) {
161         if ( cmsgp->cmsg_level == SOL_SOCKET && cmsgp->cmsg_type == SCM_RIGHTS ) {
162             return *(int *) CMSG_DATA(cmsgp);
163         }
164     }
165
166     if ( ret == sizeof (int) )
167         errno = *(int *)dbuf; /* Rcvd errno */
168     else
169         errno = ENOENT;    /* Default errno */
170
171     return -1;
172 }
173
174 /*
175  *  Check for client requests. We keep up to fd_table_size open descriptors in
176  *  fd_table. If the table is full and we get a new descriptor via
177  *  control_fd, we close a random decriptor in the table to make space. The
178  *  affected client will automatically reconnect. For an EOF (descriptor is
179  *  closed by the client, so a read here returns 0) comm_rcv will take care of
180  *  things and clean up fd_table. The same happens for any read/write errors.
181  */
182
183 static int check_fd(time_t timeout, const sigset_t *sigmask, time_t *now)
184 {
185     int fd;
186     fd_set readfds;
187     struct timespec tv;
188     int ret;
189     int i;
190     int maxfd = control_fd;
191     time_t t;
192
193     FD_ZERO(&readfds);
194     FD_SET(control_fd, &readfds);
195
196     for (i = 0; i != fds_in_use; i++) {
197         FD_SET(fd_table[i].fd, &readfds);
198         if (maxfd < fd_table[i].fd)
199             maxfd = fd_table[i].fd;
200     }
201
202     tv.tv_nsec = 0;
203     tv.tv_sec  = timeout;
204     if ((ret = pselect(maxfd + 1, &readfds, NULL, NULL, &tv, sigmask)) < 0) {
205         if (errno == EINTR)
206             return 0;
207         LOG(log_error, logtype_cnid, "error in select: %s",strerror(errno));
208         return -1;
209     }
210
211     time(&t);
212     if (now)
213         *now = t;
214
215     if (!ret)
216         return 0;
217
218
219     if (FD_ISSET(control_fd, &readfds)) {
220         int    l = 0;
221
222         fd = recv_cred(control_fd);
223         if (fd < 0) {
224             return -1;
225         }
226         if (fds_in_use < fd_table_size) {
227             fd_table[fds_in_use].fd = fd;
228             fd_table[fds_in_use].tm = t;
229             fds_in_use++;
230         } else {
231             time_t older = t;
232
233             for (i = 0; i != fds_in_use; i++) {
234                 if (older <= fd_table[i].tm) {
235                     older = fd_table[i].tm;
236                     l = i;
237                 }
238             }
239             close(fd_table[l].fd);
240             fd_table[l].fd = fd;
241             fd_table[l].tm = t;
242         }
243         return 0;
244     }
245
246     for (i = 0; i != fds_in_use; i++) {
247         if (FD_ISSET(fd_table[i].fd, &readfds)) {
248             fd_table[i].tm = t;
249             return fd_table[i].fd;
250         }
251     }
252     /* We should never get here */
253     return 0;
254 }
255
256 int comm_init(struct db_param *dbp, int ctrlfd, int clntfd)
257 {
258     int i;
259
260     fds_in_use = 0;
261     fd_table_size = dbp->fd_table_size;
262
263     if ((fd_table = malloc(fd_table_size * sizeof(struct connection))) == NULL) {
264         LOG(log_error, logtype_cnid, "Out of memory");
265         return -1;
266     }
267     for (i = 0; i != fd_table_size; i++)
268         fd_table[i].fd = -1;
269     /* from dup2 */
270     control_fd = ctrlfd;
271 #if 0
272     int b = 1;
273     /* this one dump core in recvmsg, great */
274     if ( setsockopt(control_fd, SOL_SOCKET, SO_PASSCRED, &b, sizeof (b)) < 0) {
275         LOG(log_error, logtype_cnid, "setsockopt SO_PASSCRED %s",  strerror(errno));
276         return -1;
277     }
278 #endif
279     /* push the first client fd */
280     fd_table[fds_in_use].fd = clntfd;
281     fds_in_use++;
282
283     return 0;
284 }
285
286 /* ------------
287    nbe of clients
288 */
289 int comm_nbe(void)
290 {
291     return fds_in_use;
292 }
293
294 /* ------------ */
295 int comm_rcv(struct cnid_dbd_rqst *rqst, time_t timeout, const sigset_t *sigmask, time_t *now)
296 {
297     char *nametmp;
298     int b;
299
300     if ((cur_fd = check_fd(timeout, sigmask, now)) < 0)
301         return -1;
302
303     if (!cur_fd)
304         return 0;
305
306     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got data on fd %u", cur_fd);
307
308     if (setnonblock(cur_fd, 1) != 0) {
309         LOG(log_error, logtype_cnid, "comm_rcv: setnonblock: %s", strerror(errno));
310         return -1;
311     }
312
313     nametmp = rqst->name;
314     if ((b = readt(cur_fd, rqst, sizeof(struct cnid_dbd_rqst), CNID_DBD_TIMEOUT))
315         != sizeof(struct cnid_dbd_rqst)) {
316         if (b)
317             LOG(log_error, logtype_cnid, "error reading message header: %s", strerror(errno));
318         invalidate_fd(cur_fd);
319         rqst->name = nametmp;
320         return 0;
321     }
322     rqst->name = nametmp;
323     if (rqst->namelen && readt(cur_fd, rqst->name, rqst->namelen, CNID_DBD_TIMEOUT)
324         != rqst->namelen) {
325         LOG(log_error, logtype_cnid, "error reading message name: %s", strerror(errno));
326         invalidate_fd(cur_fd);
327         return 0;
328     }
329     /* We set this to make life easier for logging. None of the other stuff
330        needs zero terminated strings. */
331     rqst->name[rqst->namelen] = '\0';
332
333     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got %u bytes", b + rqst->namelen);
334
335     return 1;
336 }
337
338 /* ------------ */
339 #define USE_WRITEV
340 int comm_snd(struct cnid_dbd_rply *rply)
341 {
342 #ifdef USE_WRITEV
343     struct iovec iov[2];
344     size_t towrite;
345 #endif
346
347     if (!rply->namelen) {
348         if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
349             LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
350             invalidate_fd(cur_fd);
351             return 0;
352         }
353         return 1;
354     }
355 #ifdef USE_WRITEV
356
357     iov[0].iov_base = rply;
358     iov[0].iov_len = sizeof(struct cnid_dbd_rply);
359     iov[1].iov_base = rply->name;
360     iov[1].iov_len = rply->namelen;
361     towrite = sizeof(struct cnid_dbd_rply) +rply->namelen;
362
363     if (writev(cur_fd, iov, 2) != towrite) {
364         LOG(log_error, logtype_cnid, "error writing message : %s", strerror(errno));
365         invalidate_fd(cur_fd);
366         return 0;
367     }
368 #else
369     if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
370         LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
371         invalidate_fd(cur_fd);
372         return 0;
373     }
374     if (write(cur_fd, rply->name, rply->namelen) != rply->namelen) {
375         LOG(log_error, logtype_cnid, "error writing message name: %s", strerror(errno));
376         invalidate_fd(cur_fd);
377         return 0;
378     }
379 #endif
380     return 1;
381 }
382
383