2 * Copyright (C) Joerg Lenneis 2003
3 * Copyright (C) Frank Lahm 2010
5 * All Rights Reserved. See COPYING.
17 #include <sys/param.h>
18 #include <sys/types.h>
21 #include <sys/socket.h>
22 #include <sys/select.h>
26 #include <atalk/logger.h>
27 #include <atalk/util.h>
28 #include <atalk/cnid_dbd_private.h>
/*
 * Fallback CMSG_SPACE(): number of bytes a control message with a len-byte
 * payload occupies, i.e. the aligned cmsghdr size plus the aligned payload.
 * Used to size the SCM_RIGHTS receive buffer in recv_cred().
 * NOTE(review): any #ifndef guard around this definition is not visible in
 * this listing; __CMSG_ALIGN is an implementation-private macro - confirm it
 * exists on every supported platform.
 */
34 /* Length of the space taken up by a padded control message of length len */
36 #define CMSG_SPACE(len) (__CMSG_ALIGN(sizeof(struct cmsghdr)) + __CMSG_ALIGN(len))
/*
 * Timestamp member of struct connection. The rest of the struct is not
 * visible here; the fd member is used below as fd_table[i].fd.
 * NOTE(review): check_fd() stores the value t in this field when a
 * descriptor is added to the table, and compares it when choosing an entry
 * to evict. The "respawned" wording may be stale - confirm.
 */
41 time_t tm; /* When respawned last */
/* Socket on which new client descriptors are received (see recv_cred()). */
45 static int control_fd;
/*
 * Client connection table, allocated in comm_init() with fd_table_size
 * slots. Entries [0, fds_in_use) are live. invalidate_fd() fills a removed
 * slot with the last live entry so the live entries stay contiguous.
 */
47 static struct connection *fd_table;
48 static int fd_table_size;
49 static int fds_in_use = 0;
/*
 * Remove fd from fd_table by moving the last live entry into its slot.
 *
 * NOTE(review): the lines that close fd and decrement fds_in_use are not
 * visible in this listing. The swap below reads fd_table[fds_in_use], which
 * is the last live entry only if fds_in_use has already been decremented.
 *
 * @param fd  descriptor to remove; the assert checks that the search stopped
 *            before the end of the live entries.
 */
52 static void invalidate_fd(int fd)
/* Linear search for fd among the live entries. */
58 for (i = 0; i != fds_in_use; i++)
59 if (fd_table[i].fd == fd)
62 assert(i < fds_in_use);
/* Compact: overwrite the removed slot with the tail entry and mark the tail slot unused. */
65 fd_table[i] = fd_table[fds_in_use];
66 fd_table[fds_in_use].fd = -1;
72 * non-blocking drop-in replacement for read with timeout using select
74 * @param socket (r) must be nonblocking !
75 * @param data (rw) buffer for the read data
76 * @param length (r) how many bytes to read
77 * @param timeout (r) number of seconds to try reading
 *
 * Reads in a loop until `length` bytes have been stored in `data`. When
 * read() cannot make progress, it waits for readability with select().
 * NOTE(review): how `timeout` is applied (tv setup, number of select()
 * rounds allowed) is not visible in this listing - confirm.
79 * @returns number of bytes actually read or -1 on fatal error
81 static ssize_t readt(int socket, void *data, const size_t length, int timeout)
/* Append at data + stored until the full length has arrived. */
91 while (stored < length) {
92 len = read(socket, (u_int8_t *) data + stored, length - stored);
/*
 * NOTE(review): the condition guarding this wait is elided; presumably
 * read() reported EAGAIN/EWOULDBLOCK on the non-blocking socket.
 * Add the socket to the select() read set before waiting.
 */
102 FD_SET(socket, &rfds);
/* Loop while select() reports a timeout (0) or an error (< 0). */
103 while ((ret = select(socket + 1, &rfds, NULL, NULL, &tv)) < 1) {
106 LOG(log_warning, logtype_cnid, "select timeout 1s");
109 LOG(log_error, logtype_cnid, "select: %s", strerror(errno));
/* Presumably reached for read() failures other than the would-block case. */
115 LOG(log_error, logtype_cnid, "read: %s", strerror(errno));
/*
 * Receive one descriptor passed with SCM_RIGHTS over the UNIX-domain
 * socket fd. At the call site in check_fd(), fd is control_fd.
 *
 * The sender may send an errno value as the in-band payload instead. If no
 * SCM_RIGHTS message arrives and exactly sizeof(int) bytes were received,
 * errno is set from that payload. Otherwise (presumably the elided else
 * branch) errno defaults to ENOENT.
 *
 * @param fd  socket to receive from.
 * @returns   the received descriptor. NOTE(review): the failure return value
 *            is elided - presumably -1 with errno set as described above.
 */
127 static int recv_cred(int fd)
132 struct cmsghdr *cmsgp = NULL;
/* Control buffer sized for exactly one int, i.e. one passed descriptor. */
133 char buf[CMSG_SPACE(sizeof(int))];
136 memset(&msgh,0,sizeof(msgh));
137 memset(buf,0,sizeof(buf));
139 msgh.msg_name = NULL;
140 msgh.msg_namelen = 0;
/* In-band payload buffer; it may carry an errno value from the sender. */
145 iov[0].iov_base = dbuf;
146 iov[0].iov_len = sizeof(dbuf);
148 msgh.msg_control = buf;
149 msgh.msg_controllen = sizeof(buf);
/* Retry when recvmsg() is interrupted by a signal. */
152 ret = recvmsg(fd ,&msgh,0);
153 } while ( ret == -1 && errno == EINTR );
/*
 * Return the first descriptor found in an SCM_RIGHTS control message.
 * NOTE(review): POSIX does not guarantee that CMSG_DATA() is aligned for a
 * direct int dereference; memcpy() into an int is the portable form.
 */
159 for ( cmsgp = CMSG_FIRSTHDR(&msgh); cmsgp != NULL; cmsgp = CMSG_NXTHDR(&msgh,cmsgp) ) {
160 if ( cmsgp->cmsg_level == SOL_SOCKET && cmsgp->cmsg_type == SCM_RIGHTS ) {
161 return *(int *) CMSG_DATA(cmsgp);
/* No descriptor received: report the sender's errno if one was sent. */
165 if ( ret == sizeof (int) )
166 errno = *(int *)dbuf; /* Rcvd errno */
168 errno = ENOENT; /* Default errno */
174 * Check for client requests. We keep up to fd_table_size open descriptors in
175 * fd_table. If the table is full and we get a new descriptor via
176 * control_fd, we close a random descriptor in the table to make space. The
177 * affected client will automatically reconnect. For an EOF (descriptor is
178 * closed by the client, so a read here returns 0) comm_rcv will take care of
179 * things and clean up fd_table. The same happens for any read/write errors.
 *
 * @param timeout  NOTE(review): presumably the wait time for pselect(); the tv
 *                 setup is not visible in this listing - confirm.
 * @param sigmask  signal mask that pselect() installs for the duration of the wait.
 * @param now      NOTE(review): its use is not visible in this listing.
 * @returns        a live client descriptor with pending data. comm_rcv()
 *                 treats a negative result as failure. NOTE(review): the
 *                 return paths for a timeout and for "only control_fd was
 *                 readable" are elided.
 *
 * NOTE(review): the eviction loop below is not random. The comparison
 * `older <= fd_table[i].tm` tracks the largest tm seen, starting from the
 * elided initial value of `older`, which selects the most recently added
 * entry. If the intent suggested by the name `older` is to evict the oldest
 * connection, the comparison looks inverted - confirm.
182 static int check_fd(time_t timeout, const sigset_t *sigmask, time_t *now)
/* Build the read set: control_fd plus every live client descriptor. */
189 int maxfd = control_fd;
193 FD_SET(control_fd, &readfds);
195 for (i = 0; i != fds_in_use; i++) {
196 FD_SET(fd_table[i].fd, &readfds);
197 if (maxfd < fd_table[i].fd)
198 maxfd = fd_table[i].fd;
/* pselect() installs sigmask atomically for the wait, avoiding signal races. */
203 if ((ret = pselect(maxfd + 1, &readfds, NULL, NULL, &tv, sigmask)) < 0) {
206 LOG(log_error, logtype_cnid, "error in select: %s",strerror(errno));
/* A readable control_fd means a new client descriptor is being handed over. */
218 if (FD_ISSET(control_fd, &readfds)) {
221 fd = recv_cred(control_fd);
/* Use a free slot if one is left ... */
225 if (fds_in_use < fd_table_size) {
226 fd_table[fds_in_use].fd = fd;
227 fd_table[fds_in_use].tm = t;
/* ... otherwise pick an existing entry by its tm value and close it to make room. */
232 for (i = 0; i != fds_in_use; i++) {
233 if (older <= fd_table[i].tm) {
234 older = fd_table[i].tm;
238 close(fd_table[l].fd);
/* Hand the first client descriptor with pending data to the caller. */
245 for (i = 0; i != fds_in_use; i++) {
246 if (FD_ISSET(fd_table[i].fd, &readfds)) {
248 return fd_table[i].fd;
251 /* We should never get here */
/*
 * Set up the connection table and register the first client.
 *
 * @param dbp     supplies fd_table_size, the maximum number of tracked clients.
 * @param ctrlfd  NOTE(review): presumably stored in control_fd; the assignment
 *                is not visible in this listing - confirm.
 * @param clntfd  first client descriptor; placed in fd_table.
 * @returns       NOTE(review): the return statements are elided - confirm the
 *                success and failure values against callers.
 */
255 int comm_init(struct db_param *dbp, int ctrlfd, int clntfd)
260 fd_table_size = dbp->fd_table_size;
/*
 * NOTE(review): the size multiplication is unchecked. calloc() would check
 * it for overflow and zero the table in one step.
 */
262 if ((fd_table = malloc(fd_table_size * sizeof(struct connection))) == NULL) {
263 LOG(log_error, logtype_cnid, "Out of memory");
/* Presumably marks every slot unused; the loop body is elided. */
266 for (i = 0; i != fd_table_size; i++)
/*
 * NOTE(review): the comment below says this SO_PASSCRED call crashes in
 * recvmsg. Any preprocessor guard that disables it is not visible here -
 * confirm whether it is compiled in.
 */
272 /* this one dump core in recvmsg, great */
273 if ( setsockopt(control_fd, SOL_SOCKET, SO_PASSCRED, &b, sizeof (b)) < 0) {
274 LOG(log_error, logtype_cnid, "setsockopt SO_PASSCRED %s", strerror(errno));
278 /* push the first client fd */
279 fd_table[fds_in_use].fd = clntfd;
/*
 * Wait for the next client request and read it into *rqst.
 *
 * The fixed-size header is read first. Then rqst->namelen bytes of name are
 * read into the caller-supplied rqst->name buffer, which is NUL-terminated.
 * A descriptor that fails either read is removed with invalidate_fd().
 *
 * @param rqst     request to fill; rqst->name must point to the caller's buffer.
 * @param timeout  passed through to check_fd().
 * @param sigmask  passed through to check_fd(), which hands it to pselect().
 * @param now      passed through to check_fd().
 * @returns        NOTE(review): the return statements are elided - confirm
 *                 against callers.
 *
 * NOTE(review): rqst->namelen comes from the client. It is used both as the
 * read length and as the index of the terminating NUL, and no upper-bound
 * check is visible here. Unless an elided line enforces one, an oversized
 * namelen overflows rqst->name.
 */
294 int comm_rcv(struct cnid_dbd_rqst *rqst, time_t timeout, const sigset_t *sigmask, time_t *now)
299 if ((cur_fd = check_fd(timeout, sigmask, now)) < 0)
305 LOG(log_maxdebug, logtype_cnid, "comm_rcv: got data on fd %u", cur_fd);
/* readt() requires a non-blocking socket. */
307 if (setnonblock(cur_fd, 1) != 0) {
308 LOG(log_error, logtype_cnid, "comm_rcv: setnonblock: %s", strerror(errno));
/*
 * Reading the header overwrites rqst->name with the pointer value sent by
 * the client. Save the caller's buffer pointer and restore it on both paths.
 */
312 nametmp = rqst->name;
313 if ((b = readt(cur_fd, rqst, sizeof(struct cnid_dbd_rqst), CNID_DBD_TIMEOUT))
314 != sizeof(struct cnid_dbd_rqst)) {
316 LOG(log_error, logtype_cnid, "error reading message header: %s", strerror(errno));
317 invalidate_fd(cur_fd);
318 rqst->name = nametmp;
321 rqst->name = nametmp;
322 if (rqst->namelen && readt(cur_fd, rqst->name, rqst->namelen, CNID_DBD_TIMEOUT)
324 LOG(log_error, logtype_cnid, "error reading message name: %s", strerror(errno));
325 invalidate_fd(cur_fd);
328 /* We set this to make life easier for logging. None of the other stuff
329 needs zero terminated strings. */
330 rqst->name[rqst->namelen] = '\0';
/*
 * NOTE(review): %u is used here and in the "got data on fd" message for
 * values that are likely signed (cur_fd is compared with < 0) - confirm the
 * argument types match the format.
 */
332 LOG(log_maxdebug, logtype_cnid, "comm_rcv: got %u bytes", b + rqst->namelen);
/*
 * Send *rply to the client currently being served (cur_fd).
 *
 * A reply without a name is sent as the fixed-size header alone. Otherwise
 * the header and rply->namelen bytes of rply->name are sent, either in one
 * writev() call or as two separate write() calls. The condition selecting
 * between these two paths is elided; presumably it is a writev()
 * availability check - confirm.
 *
 * A short or failed write is logged and the descriptor is dropped with
 * invalidate_fd(); there is no retry. NOTE(review): comm_rcv() switches
 * cur_fd to non-blocking mode, so a full socket buffer can make these writes
 * short, and that is treated as fatal.
 *
 * @param rply  reply to send.
 * @returns     NOTE(review): the return statements are elided - confirm.
 */
339 int comm_snd(struct cnid_dbd_rply *rply)
/* Header-only reply. */
346 if (!rply->namelen) {
347 if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
348 LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
349 invalidate_fd(cur_fd);
/* Header and name gathered into a single writev() call. */
356 iov[0].iov_base = rply;
357 iov[0].iov_len = sizeof(struct cnid_dbd_rply);
358 iov[1].iov_base = rply->name;
359 iov[1].iov_len = rply->namelen;
360 towrite = sizeof(struct cnid_dbd_rply) +rply->namelen;
362 if (writev(cur_fd, iov, 2) != towrite) {
363 LOG(log_error, logtype_cnid, "error writing message : %s", strerror(errno));
364 invalidate_fd(cur_fd);
/* Alternative path: header and name written with two separate calls. */
368 if (write(cur_fd, rply, sizeof(struct cnid_dbd_rply)) != sizeof(struct cnid_dbd_rply)) {
369 LOG(log_error, logtype_cnid, "error writing message header: %s", strerror(errno));
370 invalidate_fd(cur_fd);
373 if (write(cur_fd, rply->name, rply->namelen) != rply->namelen) {
374 LOG(log_error, logtype_cnid, "error writing message name: %s", strerror(errno));
375 invalidate_fd(cur_fd);