]> arthur.barton.de Git - netatalk.git/blob - etc/cnid_dbd/comm.c
Merge remote branch 'sf/master'
[netatalk.git] / etc / cnid_dbd / comm.c
1 /*
2  * Copyright (C) Joerg Lenneis 2003
3  * Copyright (C) Frank Lahm 2010
4  *
5  * All Rights Reserved.  See COPYING.
6  */
7
8 #ifdef HAVE_CONFIG_H
9 #include "config.h"
10 #endif
11
12 #ifndef _XOPEN_SOURCE
13 # define _XOPEN_SOURCE 600
14 #endif
15 #ifndef __EXTENSIONS__
16 # define __EXTENSIONS__
17 #endif
18 #ifndef _GNU_SOURCE
19 # define _GNU_SOURCE
20 #endif
21 #ifndef __BSD_VISIBLE
22 /* for u_short, u_char, etc. */
23 # define __BSD_VISIBLE 1
24 #endif
25
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <errno.h>
30 #include <unistd.h>
31 #include <sys/param.h>
32 #include <sys/types.h>
33 #include <sys/time.h>
34 #include <sys/uio.h>
35 #include <sys/socket.h>
36 #include <sys/select.h>
37 #include <assert.h>
38 #include <time.h>
39
40 #include <atalk/logger.h>
41 #include <atalk/util.h>
42 #include <atalk/cnid_dbd_private.h>
43
44 #include "db_param.h"
45 #include "usockfd.h"
46 #include "comm.h"
47
48 /* Length of the space taken up by a padded control message of length len */
49 #ifndef CMSG_SPACE
50 #define CMSG_SPACE(len) (__CMSG_ALIGN(sizeof(struct cmsghdr)) + __CMSG_ALIGN(len))
51 #endif
52
53
54 struct connection {
55     time_t tm;                    /* When respawned last */
56     int    fd;
57 };
58
59 static int   control_fd;
60 static int   cur_fd;
61 static struct connection *fd_table;
62 static int  fd_table_size;
63 static int  fds_in_use = 0;
64
65
66 static void invalidate_fd(int fd)
67 {
68     int i;
69
70     if (fd == control_fd)
71         return;
72     for (i = 0; i != fds_in_use; i++)
73         if (fd_table[i].fd == fd)
74             break;
75
76     assert(i < fds_in_use);
77
78     fds_in_use--;
79     fd_table[i] = fd_table[fds_in_use];
80     fd_table[fds_in_use].fd = -1;
81     close(fd);
82     return;
83 }
84
85
86 /*
87  *  Check for client requests. We keep up to fd_table_size open descriptors in
88  *  fd_table. If the table is full and we get a new descriptor via
89  *  control_fd, we close a random decriptor in the table to make space. The
90  *  affected client will automatically reconnect. For an EOF (descriptor is
91  *  closed by the client, so a read here returns 0) comm_rcv will take care of
92  *  things and clean up fd_table. The same happens for any read/write errors.
93  */
94
95 static int check_fd(time_t timeout, const sigset_t *sigmask, time_t *now)
96 {
97     int fd;
98     fd_set readfds;
99     struct timespec tv;
100     int ret;
101     int i;
102     int maxfd = control_fd;
103     time_t t;
104
105     FD_ZERO(&readfds);
106     FD_SET(control_fd, &readfds);
107
108     for (i = 0; i != fds_in_use; i++) {
109         FD_SET(fd_table[i].fd, &readfds);
110         if (maxfd < fd_table[i].fd)
111             maxfd = fd_table[i].fd;
112     }
113
114     tv.tv_nsec = 0;
115     tv.tv_sec  = timeout;
116     if ((ret = pselect(maxfd + 1, &readfds, NULL, NULL, &tv, sigmask)) < 0) {
117         if (errno == EINTR)
118             return 0;
119         LOG(log_error, logtype_cnid, "error in select: %s",strerror(errno));
120         return -1;
121     }
122
123     time(&t);
124     if (now)
125         *now = t;
126
127     if (!ret)
128         return 0;
129
130
131     if (FD_ISSET(control_fd, &readfds)) {
132         int    l = 0;
133
134         fd = recv_fd(control_fd, 0);
135         if (fd < 0) {
136             return -1;
137         }
138         if (fds_in_use < fd_table_size) {
139             fd_table[fds_in_use].fd = fd;
140             fd_table[fds_in_use].tm = t;
141             fds_in_use++;
142         } else {
143             time_t older = t;
144
145             for (i = 0; i != fds_in_use; i++) {
146                 if (older <= fd_table[i].tm) {
147                     older = fd_table[i].tm;
148                     l = i;
149                 }
150             }
151             close(fd_table[l].fd);
152             fd_table[l].fd = fd;
153             fd_table[l].tm = t;
154         }
155         return 0;
156     }
157
158     for (i = 0; i != fds_in_use; i++) {
159         if (FD_ISSET(fd_table[i].fd, &readfds)) {
160             fd_table[i].tm = t;
161             return fd_table[i].fd;
162         }
163     }
164     /* We should never get here */
165     return 0;
166 }
167
168 int comm_init(struct db_param *dbp, int ctrlfd, int clntfd)
169 {
170     int i;
171
172     fds_in_use = 0;
173     fd_table_size = dbp->fd_table_size;
174
175     if ((fd_table = malloc(fd_table_size * sizeof(struct connection))) == NULL) {
176         LOG(log_error, logtype_cnid, "Out of memory");
177         return -1;
178     }
179     for (i = 0; i != fd_table_size; i++)
180         fd_table[i].fd = -1;
181     /* from dup2 */
182     control_fd = ctrlfd;
183 #if 0
184     int b = 1;
185     /* this one dump core in recvmsg, great */
186     if ( setsockopt(control_fd, SOL_SOCKET, SO_PASSCRED, &b, sizeof (b)) < 0) {
187         LOG(log_error, logtype_cnid, "setsockopt SO_PASSCRED %s",  strerror(errno));
188         return -1;
189     }
190 #endif
191     /* push the first client fd */
192     fd_table[fds_in_use].fd = clntfd;
193     fds_in_use++;
194
195     return 0;
196 }
197
198 /* ------------
199    nbe of clients
200 */
201 int comm_nbe(void)
202 {
203     return fds_in_use;
204 }
205
206 /* ------------ */
207 int comm_rcv(struct cnid_dbd_rqst *rqst, time_t timeout, const sigset_t *sigmask, time_t *now)
208 {
209     char *nametmp;
210     int b;
211
212     if ((cur_fd = check_fd(timeout, sigmask, now)) < 0)
213         return -1;
214
215     if (!cur_fd)
216         return 0;
217
218     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got data on fd %u", cur_fd);
219
220     if (setnonblock(cur_fd, 1) != 0) {
221         LOG(log_error, logtype_cnid, "comm_rcv: setnonblock: %s", strerror(errno));
222         return -1;
223     }
224
225     nametmp = rqst->name;
226     if ((b = readt(cur_fd, rqst, sizeof(struct cnid_dbd_rqst), 1, CNID_DBD_TIMEOUT))
227         != sizeof(struct cnid_dbd_rqst)) {
228         if (b)
229             LOG(log_error, logtype_cnid, "error reading message header: %s", strerror(errno));
230         invalidate_fd(cur_fd);
231         rqst->name = nametmp;
232         return 0;
233     }
234     rqst->name = nametmp;
235     if (rqst->namelen && readt(cur_fd, rqst->name, rqst->namelen, 1, CNID_DBD_TIMEOUT)
236         != rqst->namelen) {
237         LOG(log_error, logtype_cnid, "error reading message name: %s", strerror(errno));
238         invalidate_fd(cur_fd);
239         return 0;
240     }
241     /* We set this to make life easier for logging. None of the other stuff
242        needs zero terminated strings. */
243     rqst->name[rqst->namelen] = '\0';
244
245     LOG(log_maxdebug, logtype_cnid, "comm_rcv: got %u bytes", b + rqst->namelen);
246
247     return 1;
248 }
249
250 /* ------------ */
251 #define USE_WRITEV
252 int comm_snd(struct cnid_dbd_rply *rply)
253 {
254 #ifdef USE_WRITEV
255     struct iovec iov[2];
256     size_t towrite;
257 #endif
258
259     if (!rply->namelen) {
260         if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
261             LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
262             invalidate_fd(cur_fd);
263             return 0;
264         }
265         return 1;
266     }
267 #ifdef USE_WRITEV
268
269     iov[0].iov_base = rply;
270     iov[0].iov_len = sizeof(struct cnid_dbd_rply);
271     iov[1].iov_base = rply->name;
272     iov[1].iov_len = rply->namelen;
273     towrite = sizeof(struct cnid_dbd_rply) +rply->namelen;
274
275     if (writev(cur_fd, iov, 2) != towrite) {
276         LOG(log_error, logtype_cnid, "error writing message : %s", strerror(errno));
277         invalidate_fd(cur_fd);
278         return 0;
279     }
280 #else
281     if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
282         LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
283         invalidate_fd(cur_fd);
284         return 0;
285     }
286     if (write(cur_fd, rply->name, rply->namelen) != rply->namelen) {
287         LOG(log_error, logtype_cnid, "error writing message name: %s", strerror(errno));
288         invalidate_fd(cur_fd);
289         return 0;
290     }
291 #endif
292     return 1;
293 }
294
295