]> arthur.barton.de Git - netatalk.git/blob - etc/cnid_dbd/comm.c
7cd8f5598c7cff49930ac147f799e7dea1d210bb
[netatalk.git] / etc / cnid_dbd / comm.c
1 /*
2  * Copyright (C) Joerg Lenneis 2003
3  * Copyright (C) Frank Lahm 2010
4  *
5  * All Rights Reserved.  See COPYING.
6  */
7
8 #ifdef HAVE_CONFIG_H
9 #include "config.h"
10 #endif
11
12 #include <atalk/standards.h>
13
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <errno.h>
18 #include <unistd.h>
19 #include <sys/param.h>
20 #include <sys/types.h>
21 #include <sys/time.h>
22 #include <sys/uio.h>
23 #include <sys/socket.h>
24 #include <sys/select.h>
25 #include <assert.h>
26 #include <time.h>
27
28 #include <atalk/logger.h>
29 #include <atalk/util.h>
30 #include <atalk/cnid_dbd_private.h>
31 #include <atalk/compat.h>
32
33 #include "db_param.h"
34 #include "usockfd.h"
35 #include "comm.h"
36
37 /* Length of the space taken up by a padded control message of length len */
38 #ifndef CMSG_SPACE
39 #define CMSG_SPACE(len) (__CMSG_ALIGN(sizeof(struct cmsghdr)) + __CMSG_ALIGN(len))
40 #endif
41
42
43 struct connection {
44     time_t tm;                    /* When respawned last */
45     int    fd;
46 };
47
48 static int   control_fd;
49 static int   cur_fd;
50 static struct connection *fd_table;
51 static int  fd_table_size;
52 static int  fds_in_use = 0;
53
54
55 static void invalidate_fd(int fd)
56 {
57     int i;
58
59     if (fd == control_fd)
60         return;
61     for (i = 0; i != fds_in_use; i++)
62         if (fd_table[i].fd == fd)
63             break;
64
65     assert(i < fds_in_use);
66
67     fds_in_use--;
68     fd_table[i] = fd_table[fds_in_use];
69     fd_table[fds_in_use].fd = -1;
70     close(fd);
71     return;
72 }
73
74
75 /*
76  *  Check for client requests. We keep up to fd_table_size open descriptors in
77  *  fd_table. If the table is full and we get a new descriptor via
78  *  control_fd, we close a random decriptor in the table to make space. The
79  *  affected client will automatically reconnect. For an EOF (descriptor is
80  *  closed by the client, so a read here returns 0) comm_rcv will take care of
81  *  things and clean up fd_table. The same happens for any read/write errors.
82  */
83
84 static int check_fd(time_t timeout, const sigset_t *sigmask, time_t *now)
85 {
86     int fd;
87     fd_set readfds;
88     struct timespec tv;
89     int ret;
90     int i;
91     int maxfd = control_fd;
92     time_t t;
93
94     FD_ZERO(&readfds);
95     FD_SET(control_fd, &readfds);
96
97     for (i = 0; i != fds_in_use; i++) {
98         FD_SET(fd_table[i].fd, &readfds);
99         if (maxfd < fd_table[i].fd)
100             maxfd = fd_table[i].fd;
101     }
102
103     tv.tv_nsec = 0;
104     tv.tv_sec  = timeout;
105     if ((ret = pselect(maxfd + 1, &readfds, NULL, NULL, &tv, sigmask)) < 0) {
106         if (errno == EINTR)
107             return 0;
108         LOG(log_error, logtype_cnid, "error in select: %s",strerror(errno));
109         return -1;
110     }
111
112     time(&t);
113     if (now)
114         *now = t;
115
116     if (!ret)
117         return 0;
118
119
120     if (FD_ISSET(control_fd, &readfds)) {
121         int    l = 0;
122
123         fd = recv_fd(control_fd, 0);
124         if (fd < 0) {
125             return -1;
126         }
127         if (fds_in_use < fd_table_size) {
128             fd_table[fds_in_use].fd = fd;
129             fd_table[fds_in_use].tm = t;
130             fds_in_use++;
131         } else {
132             time_t older = t;
133
134             for (i = 0; i != fds_in_use; i++) {
135                 if (older <= fd_table[i].tm) {
136                     older = fd_table[i].tm;
137                     l = i;
138                 }
139             }
140             close(fd_table[l].fd);
141             fd_table[l].fd = fd;
142             fd_table[l].tm = t;
143         }
144         return 0;
145     }
146
147     for (i = 0; i != fds_in_use; i++) {
148         if (FD_ISSET(fd_table[i].fd, &readfds)) {
149             fd_table[i].tm = t;
150             return fd_table[i].fd;
151         }
152     }
153     /* We should never get here */
154     return 0;
155 }
156
157 int comm_init(struct db_param *dbp, int ctrlfd, int clntfd)
158 {
159     int i;
160
161     fds_in_use = 0;
162     fd_table_size = dbp->fd_table_size;
163
164     if ((fd_table = malloc(fd_table_size * sizeof(struct connection))) == NULL) {
165         LOG(log_error, logtype_cnid, "Out of memory");
166         return -1;
167     }
168     for (i = 0; i != fd_table_size; i++)
169         fd_table[i].fd = -1;
170     /* from dup2 */
171     control_fd = ctrlfd;
172 #if 0
173     int b = 1;
174     /* this one dump core in recvmsg, great */
175     if ( setsockopt(control_fd, SOL_SOCKET, SO_PASSCRED, &b, sizeof (b)) < 0) {
176         LOG(log_error, logtype_cnid, "setsockopt SO_PASSCRED %s",  strerror(errno));
177         return -1;
178     }
179 #endif
180     /* push the first client fd */
181     fd_table[fds_in_use].fd = clntfd;
182     fds_in_use++;
183
184     return 0;
185 }
186
187 /* ------------
188    nbe of clients
189 */
190 int comm_nbe(void)
191 {
192     return fds_in_use;
193 }
194
195 /* ------------ */
196 int comm_rcv(struct cnid_dbd_rqst *rqst, time_t timeout, const sigset_t *sigmask, time_t *now)
197 {
198     char *nametmp;
199     int b;
200
201     if ((cur_fd = check_fd(timeout, sigmask, now)) < 0)
202         return -1;
203
204     if (!cur_fd)
205         return 0;
206
207     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got data on fd %u", cur_fd);
208
209     if (setnonblock(cur_fd, 1) != 0) {
210         LOG(log_error, logtype_cnid, "comm_rcv: setnonblock: %s", strerror(errno));
211         return -1;
212     }
213
214     nametmp = (char *)rqst->name;
215     if ((b = readt(cur_fd, rqst, sizeof(struct cnid_dbd_rqst), 1, CNID_DBD_TIMEOUT))
216         != sizeof(struct cnid_dbd_rqst)) {
217         if (b)
218             LOG(log_error, logtype_cnid, "error reading message header: %s", strerror(errno));
219         invalidate_fd(cur_fd);
220         rqst->name = nametmp;
221         return 0;
222     }
223     rqst->name = nametmp;
224     if (rqst->namelen && readt(cur_fd, (char *)rqst->name, rqst->namelen, 1, CNID_DBD_TIMEOUT)
225         != rqst->namelen) {
226         LOG(log_error, logtype_cnid, "error reading message name: %s", strerror(errno));
227         invalidate_fd(cur_fd);
228         return 0;
229     }
230     /* We set this to make life easier for logging. None of the other stuff
231        needs zero terminated strings. */
232     ((char *)(rqst->name))[rqst->namelen] = '\0';
233
234     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got %u bytes", b + rqst->namelen);
235
236     return 1;
237 }
238
239 /* ------------ */
240 #define USE_WRITEV
241 int comm_snd(struct cnid_dbd_rply *rply)
242 {
243 #ifdef USE_WRITEV
244     struct iovec iov[2];
245     size_t towrite;
246 #endif
247
248     if (!rply->namelen) {
249         if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
250             LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
251             invalidate_fd(cur_fd);
252             return 0;
253         }
254         return 1;
255     }
256 #ifdef USE_WRITEV
257
258     iov[0].iov_base = rply;
259     iov[0].iov_len = sizeof(struct cnid_dbd_rply);
260     iov[1].iov_base = rply->name;
261     iov[1].iov_len = rply->namelen;
262     towrite = sizeof(struct cnid_dbd_rply) +rply->namelen;
263
264     if (writev(cur_fd, iov, 2) != towrite) {
265         LOG(log_error, logtype_cnid, "error writing message : %s", strerror(errno));
266         invalidate_fd(cur_fd);
267         return 0;
268     }
269 #else
270     if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
271         LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
272         invalidate_fd(cur_fd);
273         return 0;
274     }
275     if (write(cur_fd, rply->name, rply->namelen) != rply->namelen) {
276         LOG(log_error, logtype_cnid, "error writing message name: %s", strerror(errno));
277         invalidate_fd(cur_fd);
278         return 0;
279     }
280 #endif
281     return 1;
282 }
283
284