patch-2.1.56 linux/fs/smbfs/sock.c
Next file: linux/fs/sysv/fsync.c
Previous file: linux/fs/smbfs/proc.c
Back to the patch index
Back to the overall index
- Lines: 605
- Date: Sun Sep 14 15:14:56 1997
- Orig file: v2.1.55/linux/fs/smbfs/sock.c
- Orig date: Thu Sep 11 09:02:24 1997
diff -u --recursive --new-file v2.1.55/linux/fs/smbfs/sock.c linux/fs/smbfs/sock.c
@@ -24,10 +24,13 @@
#include <asm/uaccess.h>
+#define SMBFS_PARANOIA 1
+/* #define SMBFS_DEBUG_VERBOSE 1 */
+
#define _S(nr) (1<<((nr)-1))
static int
-_recvfrom(struct socket *sock, unsigned char *ubuf, int size,
+_recvfrom(struct socket *socket, unsigned char *ubuf, int size,
unsigned flags)
{
struct iovec iov;
@@ -43,14 +46,14 @@
iov.iov_len = size;
memset(&scm, 0,sizeof(scm));
- size=sock->ops->recvmsg(sock, &msg, size, flags, &scm);
+ size=socket->ops->recvmsg(socket, &msg, size, flags, &scm);
if(size>=0)
- scm_recv(sock,&msg,&scm,flags);
+ scm_recv(socket,&msg,&scm,flags);
return size;
}
static int
-_send(struct socket *sock, const void *buff, int len)
+_send(struct socket *socket, const void *buff, int len)
{
struct iovec iov;
struct msghdr msg;
@@ -69,163 +72,199 @@
msg.msg_flags = 0;
- err = scm_send(sock, &msg, &scm);
- if (err < 0)
- return err;
- err = sock->ops->sendmsg(sock, &msg, len, &scm);
- scm_destroy(&scm);
+ err = scm_send(socket, &msg, &scm);
+ if (err >= 0)
+ {
+ err = socket->ops->sendmsg(socket, &msg, len, &scm);
+ scm_destroy(&scm);
+ }
return err;
}
+/*
+ * N.B. What happens if we're in here when the socket closes??
+ */
static void
smb_data_callback(struct sock *sk, int len)
{
- struct socket *sock = sk->socket;
-
- if (!sk->dead)
- {
- unsigned char peek_buf[4];
- int result;
- unsigned long fs;
+ struct socket *socket = sk->socket;
+ unsigned char peek_buf[4];
+ int result;
+ unsigned long fs;
- fs = get_fs();
- set_fs(get_ds());
+ fs = get_fs();
+ set_fs(get_ds());
- result = _recvfrom(sock, (void *) peek_buf, 1,
+ while (1)
+ {
+ if (sk->dead)
+ {
+ printk("smb_data_callback: sock dead!\n");
+ return;
+ }
+ result = _recvfrom(socket, (void *) peek_buf, 1,
MSG_PEEK | MSG_DONTWAIT);
+ if (result == -EAGAIN)
+ break;
+ if (peek_buf[0] != 0x85)
+ break;
- while ((result != -EAGAIN) && (peek_buf[0] == 0x85))
- {
- /* got SESSION KEEP ALIVE */
- result = _recvfrom(sock, (void *) peek_buf, 4,
- MSG_DONTWAIT);
+ /* got SESSION KEEP ALIVE */
+ result = _recvfrom(socket, (void *) peek_buf, 4,
+ MSG_DONTWAIT);
- pr_debug("smb_data_callback: got SESSION KEEPALIVE\n");
+ pr_debug("smb_data_callback: got SESSION KEEPALIVE\n");
- if (result == -EAGAIN)
- {
- break;
- }
- result = _recvfrom(sock, (void *) peek_buf, 1,
- MSG_PEEK | MSG_DONTWAIT);
- }
- set_fs(fs);
+ if (result == -EAGAIN)
+ break;
+ }
+ set_fs(fs);
- if (result != -EAGAIN)
- {
- wake_up_interruptible(sk->sleep);
- }
+ if (result != -EAGAIN)
+ {
+ wake_up_interruptible(sk->sleep);
}
}
-int
-smb_catch_keepalive(struct smb_sb_info *server)
+static struct socket *
+server_sock(struct smb_sb_info *server)
{
struct file *file;
struct inode *inode;
- struct socket *sock;
- struct sock *sk;
- if ((server == NULL)
- || ((file = server->sock_file) == NULL)
- || ((inode = file->f_dentry->d_inode) == NULL)
- || (!S_ISSOCK(inode->i_mode)))
- {
- pr_debug("smb_catch_keepalive: did not get valid server!\n");
- server->data_ready = NULL;
- return -EINVAL;
- }
- sock = &(inode->u.socket_i);
+ if (server &&
+ (file = server->sock_file) &&
+ (inode = file->f_dentry->d_inode) &&
+ S_ISSOCK(inode->i_mode) &&
+ inode->u.socket_i.type == SOCK_STREAM)
+ return &(inode->u.socket_i);
+ return NULL;
+}
- if (sock->type != SOCK_STREAM)
+int
+smb_catch_keepalive(struct smb_sb_info *server)
+{
+ struct socket *socket;
+ struct sock *sk;
+ void *data_ready;
+ int error;
+
+ error = -EINVAL;
+ socket = server_sock(server);
+ if (!socket)
{
- pr_debug("smb_catch_keepalive: did not get SOCK_STREAM\n");
+ printk("smb_catch_keepalive: did not get valid server!\n");
server->data_ready = NULL;
- return -EINVAL;
+ goto out;
}
- sk = sock->sk;
+ sk = socket->sk;
if (sk == NULL)
{
pr_debug("smb_catch_keepalive: sk == NULL");
server->data_ready = NULL;
- return -EINVAL;
+ goto out;
}
pr_debug("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n",
(unsigned int) (sk->data_ready),
(unsigned int) (server->data_ready));
- if (sk->data_ready == smb_data_callback)
- {
+ /*
+ * Install the callback atomically to avoid races ...
+ */
+ data_ready = xchg(&sk->data_ready, smb_data_callback);
+ if (data_ready != smb_data_callback)
+ {
+ server->data_ready = data_ready;
+ error = 0;
+ } else
printk(KERN_ERR "smb_catch_keepalive: already done\n");
- return -EINVAL;
- }
- server->data_ready = sk->data_ready;
- sk->data_ready = smb_data_callback;
- return 0;
+out:
+ return error;
}
int
smb_dont_catch_keepalive(struct smb_sb_info *server)
{
- struct file *file;
- struct inode *inode;
- struct socket *sock;
+ struct socket *socket;
struct sock *sk;
+ void * data_ready;
+ int error;
- if ((server == NULL)
- || ((file = server->sock_file) == NULL)
- || ((inode = file->f_dentry->d_inode) == NULL)
- || (!S_ISSOCK(inode->i_mode)))
+ error = -EINVAL;
+ socket = server_sock(server);
+ if (!socket)
{
- printk("smb_dont_catch_keepalive: "
- "did not get valid server!\n");
- return -EINVAL;
+ printk("smb_dont_catch_keepalive: did not get valid server!\n");
+ goto out;
}
- sock = &(inode->u.socket_i);
-
- if (sock->type != SOCK_STREAM)
- {
- printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n");
- return -EINVAL;
- }
- sk = sock->sk;
+ sk = socket->sk;
if (sk == NULL)
{
printk("smb_dont_catch_keepalive: sk == NULL");
- return -EINVAL;
+ goto out;
}
+
+ /* Is this really an error?? */
if (server->data_ready == NULL)
{
printk("smb_dont_catch_keepalive: "
"server->data_ready == NULL\n");
- return -EINVAL;
- }
- if (sk->data_ready != smb_data_callback)
- {
- printk("smb_dont_catch_keepalive: "
- "sk->data_callback != smb_data_callback\n");
- return -EINVAL;
+ goto out;
}
pr_debug("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n",
(unsigned int) (sk->data_ready),
(unsigned int) (server->data_ready));
- sk->data_ready = server->data_ready;
+ /*
+ * Restore the original callback atomically to avoid races ...
+ */
+ data_ready = xchg(&sk->data_ready, server->data_ready);
server->data_ready = NULL;
- return 0;
+ if (data_ready != smb_data_callback)
+ {
+ printk("smb_dont_catch_keepalive: "
+ "sk->data_callback != smb_data_callback\n");
+ }
+ error = 0;
+out:
+ return error;
+}
+
+/*
+ * Called with the server locked.
+ */
+void
+smb_close_socket(struct smb_sb_info *server)
+{
+ struct file * file = server->sock_file;
+
+ if (file)
+ {
+ struct socket * socket = server_sock(server);
+
+ printk("smb_close_socket: closing socket %p\n", socket);
+ /*
+ * We need a way to check for tasks running the callback!
+ */
+ if (socket->sk->data_ready == smb_data_callback)
+ printk("smb_close_socket: still catching keepalives!\n");
+
+ server->sock_file = NULL;
+ close_fp(file);
+ }
}
static int
-smb_send_raw(struct socket *sock, unsigned char *source, int length)
+smb_send_raw(struct socket *socket, unsigned char *source, int length)
{
int result;
int already_sent = 0;
while (already_sent < length)
{
- result = _send(sock,
+ result = _send(socket,
(void *) (source + already_sent),
length - already_sent);
@@ -245,14 +284,14 @@
}
static int
-smb_receive_raw(struct socket *sock, unsigned char *target, int length)
+smb_receive_raw(struct socket *socket, unsigned char *target, int length)
{
int result;
int already_read = 0;
while (already_read < length)
{
- result = _recvfrom(sock,
+ result = _recvfrom(socket,
(void *) (target + already_read),
length - already_read, 0);
@@ -272,7 +311,7 @@
}
static int
-smb_get_length(struct socket *sock, unsigned char *header)
+smb_get_length(struct socket *socket, unsigned char *header)
{
int result;
unsigned char peek_buf[4];
@@ -281,7 +320,7 @@
re_recv:
fs = get_fs();
set_fs(get_ds());
- result = smb_receive_raw(sock, peek_buf, 4);
+ result = smb_receive_raw(socket, peek_buf, 4);
set_fs(fs);
if (result < 0)
@@ -312,21 +351,6 @@
return smb_len(peek_buf);
}
-static struct socket *
-server_sock(struct smb_sb_info *server)
-{
- struct file *file;
- struct inode *inode;
-
- if (server == NULL)
- return NULL;
- if ((file = server->sock_file) == NULL)
- return NULL;
- if ((inode = file->f_dentry->d_inode) == NULL)
- return NULL;
- return &(inode->u.socket_i);
-}
-
/*
* smb_receive
* fs points to the correct segment
@@ -334,12 +358,12 @@
static int
smb_receive(struct smb_sb_info *server)
{
- struct socket *sock = server_sock(server);
+ struct socket *socket = server_sock(server);
int len;
int result;
unsigned char peek_buf[4];
- len = smb_get_length(sock, peek_buf);
+ len = smb_get_length(socket, peek_buf);
if (len < 0)
{
@@ -352,6 +376,7 @@
pr_debug("smb_receive: Increase packet size from %d to %d\n",
server->packet_size, len + 4);
smb_vfree(server->packet);
+ server->packet = 0;
server->packet_size = 0;
server->packet = smb_vmalloc(len + 4);
if (server->packet == NULL)
@@ -361,7 +386,7 @@
server->packet_size = len + 4;
}
memcpy(server->packet, peek_buf, 4);
- result = smb_receive_raw(sock, server->packet + 4, len);
+ result = smb_receive_raw(socket, server->packet + 4, len);
if (result < 0)
{
@@ -371,11 +396,10 @@
server->rcls = *(server->packet+9);
server->err = WVAL(server->packet, 11);
- if (server->rcls != 0)
- {
- pr_debug("smb_receive: rcls=%d, err=%d\n",
- server->rcls, server->err);
- }
+#ifdef SMBFS_DEBUG_VERBOSE
+if (server->rcls != 0)
+printk("smb_receive: rcls=%d, err=%d\n", server->rcls, server->err);
+#endif
return result;
}
@@ -469,6 +493,13 @@
total_data = WVAL(inbuf, smb_tdrcnt);
total_param = WVAL(inbuf, smb_tprcnt);
+#ifdef SMBFS_PARANOIA
+if ((data_len >= total_data || param_len >= total_param) &&
+ !(data_len >= total_data && param_len >= total_param))
+printk("smb_receive_trans2: dlen=%d, tdata=%d, plen=%d, tlen=%d\n",
+data_len, total_data, param_len, total_param);
+#endif
+ /* shouldn't this be an OR test? don't want to overrun */
if ((data_len >= total_data) && (param_len >= total_param))
{
break;
@@ -477,11 +508,9 @@
{
goto fail;
}
+ result = -EIO;
if (server->rcls != 0)
- {
- result = -EIO;
goto fail;
- }
}
*ldata = data_len;
*lparam = param_len;
@@ -496,32 +525,33 @@
return result;
}
+/*
+ * Called with the server locked
+ */
int
smb_request(struct smb_sb_info *server)
{
unsigned long old_mask;
unsigned long fs;
int len, result;
+ unsigned char *buffer;
- unsigned char *buffer = (server == NULL) ? NULL : server->packet;
+ result = -EBADF;
+ if (!server) /* this can't happen */
+ goto bad_no_server;
+
+ buffer = server->packet;
+ if (!buffer)
+ goto bad_no_packet;
- if (buffer == NULL)
- {
- pr_debug("smb_request: Bad server!\n");
- return -EBADF;
- }
+ result = -EIO;
if (server->state != CONN_VALID)
- {
- return -EIO;
- }
+ goto bad_no_conn;
+
if ((result = smb_dont_catch_keepalive(server)) != 0)
- {
- server->state = CONN_INVALID;
- smb_invalidate_inodes(server);
- return result;
- }
- len = smb_len(buffer) + 4;
+ goto bad_conn;
+ len = smb_len(buffer) + 4;
pr_debug("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]);
old_mask = current->blocked;
@@ -544,17 +574,31 @@
int result2 = smb_catch_keepalive(server);
if (result2 < 0)
{
+ printk("smb_request: catch keepalive failed\n");
result = result2;
}
}
if (result < 0)
- {
- server->state = CONN_INVALID;
- smb_invalidate_inodes(server);
- }
- pr_debug("smb_request: result = %d\n", result);
+ goto bad_conn;
+out:
+ pr_debug("smb_request: result = %d\n", result);
return result;
+
+bad_conn:
+ printk("smb_request: result %d, setting invalid\n", result);
+ server->state = CONN_INVALID;
+ smb_invalidate_inodes(server);
+ goto out;
+bad_no_server:
+ printk("smb_request: no server!\n");
+ goto out;
+bad_no_packet:
+ printk("smb_request: no packet!\n");
+ goto out;
+bad_no_conn:
+ printk("smb_request: connection %d not valid!\n", server->state);
+ goto out;
}
#define ROUND_UP(x) (((x)+3) & ~3)
@@ -629,11 +673,11 @@
iov[3].iov_len = ldata;
err = scm_send(sock, &msg, &scm);
- if (err < 0)
- return err;
-
- err = sock->ops->sendmsg(sock, &msg, packet_length, &scm);
- scm_destroy(&scm);
+ if (err >= 0)
+ {
+ err = sock->ops->sendmsg(sock, &msg, packet_length, &scm);
+ scm_destroy(&scm);
+ }
return err;
}
@@ -655,16 +699,19 @@
pr_debug("smb_trans2_request: com=%d, ld=%d, lp=%d\n",
trans2_command, ldata, lparam);
+ /*
+ * These are initialized in smb_request_ok, but not here??
+ */
+ server->rcls = 0;
+ server->err = 0;
+
+ result = -EIO;
if (server->state != CONN_VALID)
- {
- return -EIO;
- }
+ goto out;
+
if ((result = smb_dont_catch_keepalive(server)) != 0)
- {
- server->state = CONN_INVALID;
- smb_invalidate_inodes(server);
- return result;
- }
+ goto bad_conn;
+
old_mask = current->blocked;
current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
fs = get_fs();
@@ -691,11 +738,15 @@
}
}
if (result < 0)
- {
- server->state = CONN_INVALID;
- smb_invalidate_inodes(server);
- }
+ goto bad_conn;
pr_debug("smb_trans2_request: result = %d\n", result);
+out:
return result;
+
+bad_conn:
+ printk("smb_trans2_request: connection bad, setting invalid\n");
+ server->state = CONN_INVALID;
+ smb_invalidate_inodes(server);
+ goto out;
}
FUNET's LINUX-ADM group, linux-adm@nic.funet.fi
TCL-scripts by Sam Shen, slshen@lbl.gov