patch-2.1.56 linux/fs/smbfs/sock.c

Next file: linux/fs/sysv/fsync.c
Previous file: linux/fs/smbfs/proc.c
Back to the patch index
Back to the overall index

diff -u --recursive --new-file v2.1.55/linux/fs/smbfs/sock.c linux/fs/smbfs/sock.c
@@ -24,10 +24,13 @@
 
 #include <asm/uaccess.h>
 
+#define SMBFS_PARANOIA 1
+/* #define SMBFS_DEBUG_VERBOSE 1 */
+
 #define _S(nr) (1<<((nr)-1))
 
 static int
-_recvfrom(struct socket *sock, unsigned char *ubuf, int size,
+_recvfrom(struct socket *socket, unsigned char *ubuf, int size,
 	  unsigned flags)
 {
 	struct iovec iov;
@@ -43,14 +46,14 @@
 	iov.iov_len = size;
 	
 	memset(&scm, 0,sizeof(scm));
-	size=sock->ops->recvmsg(sock, &msg, size, flags, &scm);
+	size=socket->ops->recvmsg(socket, &msg, size, flags, &scm);
 	if(size>=0)
-		scm_recv(sock,&msg,&scm,flags);
+		scm_recv(socket,&msg,&scm,flags);
 	return size;
 }
 
 static int
-_send(struct socket *sock, const void *buff, int len)
+_send(struct socket *socket, const void *buff, int len)
 {
 	struct iovec iov;
 	struct msghdr msg;
@@ -69,163 +72,199 @@
 
 	msg.msg_flags = 0;
 
-	err = scm_send(sock, &msg, &scm);
-        if (err < 0)
-                return err;
-	err = sock->ops->sendmsg(sock, &msg, len, &scm);
-	scm_destroy(&scm);
+	err = scm_send(socket, &msg, &scm);
+        if (err >= 0)
+	{
+		err = socket->ops->sendmsg(socket, &msg, len, &scm);
+		scm_destroy(&scm);
+	}
 	return err;
 }
 
+/* Data-ready hook: drains SESSION KEEP ALIVE (0x85) packets, then wakes
+ * sk->sleep unless the last receive returned -EAGAIN.
+ * N.B. What happens if we're in here when the socket closes?? */
 static void
 smb_data_callback(struct sock *sk, int len)
 {
-	struct socket *sock = sk->socket;
-
-	if (!sk->dead)
-	{
-		unsigned char peek_buf[4];
-		int result;
-		unsigned long fs;
+	struct socket *socket = sk->socket;
+	unsigned char peek_buf[4];
+	int result;
+	unsigned long fs;
 
-		fs = get_fs();
-		set_fs(get_ds());
+	fs = get_fs();
+	set_fs(get_ds());
 
-		result = _recvfrom(sock, (void *) peek_buf, 1,
+	while (1)
+	{
+		if (sk->dead) {
+			printk("smb_data_callback: sock dead!\n");
+			set_fs(fs);	/* undo set_fs(get_ds()) before bailing out */
+			return;
+		}
+		result = _recvfrom(socket, (void *) peek_buf, 1,
 				   MSG_PEEK | MSG_DONTWAIT);
+		if (result == -EAGAIN)
+			break;
+		if (result < 1 || peek_buf[0] != 0x85)	/* nothing peeked: peek_buf is unset */
+			break;
 
-		while ((result != -EAGAIN) && (peek_buf[0] == 0x85))
-		{
-			/* got SESSION KEEP ALIVE */
-			result = _recvfrom(sock, (void *) peek_buf, 4,
-					   MSG_DONTWAIT);
+		/* got SESSION KEEP ALIVE */
+		result = _recvfrom(socket, (void *) peek_buf, 4,
+				   MSG_DONTWAIT);
 
-			pr_debug("smb_data_callback: got SESSION KEEPALIVE\n");
+		pr_debug("smb_data_callback: got SESSION KEEPALIVE\n");
 
-			if (result == -EAGAIN)
-			{
-				break;
-			}
-			result = _recvfrom(sock, (void *) peek_buf, 1,
-					   MSG_PEEK | MSG_DONTWAIT);
-		}
-		set_fs(fs);
+		if (result == -EAGAIN)
+			break;
+	}
+	set_fs(fs);
 
-		if (result != -EAGAIN)
-		{
-			wake_up_interruptible(sk->sleep);
-		}
+	if (result != -EAGAIN)
+	{
+		wake_up_interruptible(sk->sleep);
 	}
 }
 
-int
-smb_catch_keepalive(struct smb_sb_info *server)
+static struct socket *
+server_sock(struct smb_sb_info *server)
 {
 	struct file *file;
 	struct inode *inode;
-	struct socket *sock;
-	struct sock *sk;
 
-	if ((server == NULL)
-	    || ((file = server->sock_file) == NULL)
-	    || ((inode = file->f_dentry->d_inode) == NULL)
-	    || (!S_ISSOCK(inode->i_mode)))
-	{
-		pr_debug("smb_catch_keepalive: did not get valid server!\n");
-		server->data_ready = NULL;
-		return -EINVAL;
-	}
-	sock = &(inode->u.socket_i);
+	if (server				&& 
+	    (file = server->sock_file)		&&
+	    (inode = file->f_dentry->d_inode)	&& 
+	    S_ISSOCK(inode->i_mode)		&& 
+	    inode->u.socket_i.type == SOCK_STREAM)
+		return &(inode->u.socket_i);
+	return NULL;
+}
 
-	if (sock->type != SOCK_STREAM)
+int
+smb_catch_keepalive(struct smb_sb_info *server)
+{
+	struct socket *socket;
+	struct sock *sk;
+	void *data_ready;
+	int error;
+
+	error = -EINVAL;
+	socket = server_sock(server);
+	if (!socket)
 	{
-		pr_debug("smb_catch_keepalive: did not get SOCK_STREAM\n");
+		printk("smb_catch_keepalive: did not get valid server!\n");
 		server->data_ready = NULL;
-		return -EINVAL;
+		goto out;
 	}
-	sk = sock->sk;
 
+	sk = socket->sk;
 	if (sk == NULL)
 	{
 		pr_debug("smb_catch_keepalive: sk == NULL");
 		server->data_ready = NULL;
-		return -EINVAL;
+		goto out;
 	}
 	pr_debug("smb_catch_keepalive.: sk->d_r = %x, server->d_r = %x\n",
 		 (unsigned int) (sk->data_ready),
 		 (unsigned int) (server->data_ready));
 
-	if (sk->data_ready == smb_data_callback)
-	{
+	/*
+	 * Install the callback atomically to avoid races ...
+	 */
+	data_ready = xchg(&sk->data_ready, smb_data_callback);
+	if (data_ready != smb_data_callback)
+	{
+		server->data_ready = data_ready;
+		error = 0;
+	} else
 		printk(KERN_ERR "smb_catch_keepalive: already done\n");
-		return -EINVAL;
-	}
-	server->data_ready = sk->data_ready;
-	sk->data_ready = smb_data_callback;
-	return 0;
+out:
+	return error;
 }
 
 int
 smb_dont_catch_keepalive(struct smb_sb_info *server)
 {
-	struct file *file;
-	struct inode *inode;
-	struct socket *sock;
+	struct socket *socket;
 	struct sock *sk;
+	void * data_ready;
+	int error;
 
-	if ((server == NULL)
-	    || ((file = server->sock_file) == NULL)
-	    || ((inode = file->f_dentry->d_inode) == NULL)
-	    || (!S_ISSOCK(inode->i_mode)))
+	error = -EINVAL;
+	socket = server_sock(server);
+	if (!socket)
 	{
-		printk("smb_dont_catch_keepalive: "
-		       "did not get valid server!\n");
-		return -EINVAL;
+		printk("smb_dont_catch_keepalive: did not get valid server!\n");
+		goto out;
 	}
-	sock = &(inode->u.socket_i);
-
-	if (sock->type != SOCK_STREAM)
-	{
-		printk("smb_dont_catch_keepalive: did not get SOCK_STREAM\n");
-		return -EINVAL;
-	}
-	sk = sock->sk;
 
+	sk = socket->sk;
 	if (sk == NULL)
 	{
 		printk("smb_dont_catch_keepalive: sk == NULL");
-		return -EINVAL;
+		goto out;
 	}
+
+	/* Is this really an error?? */
 	if (server->data_ready == NULL)
 	{
 		printk("smb_dont_catch_keepalive: "
 		       "server->data_ready == NULL\n");
-		return -EINVAL;
-	}
-	if (sk->data_ready != smb_data_callback)
-	{
-		printk("smb_dont_catch_keepalive: "
-		       "sk->data_callback != smb_data_callback\n");
-		return -EINVAL;
+		goto out;
 	}
 	pr_debug("smb_dont_catch_keepalive: sk->d_r = %x, server->d_r = %x\n",
 		 (unsigned int) (sk->data_ready),
 		 (unsigned int) (server->data_ready));
 
-	sk->data_ready = server->data_ready;
+	/*
+	 * Restore the original callback atomically to avoid races ...
+	 */
+	data_ready = xchg(&sk->data_ready, server->data_ready);
 	server->data_ready = NULL;
-	return 0;
+	if (data_ready != smb_data_callback)
+	{
+		printk("smb_dont_catch_keepalive: "
+		       "sk->data_callback != smb_data_callback\n");
+	}
+	error = 0;
+out:
+	return error;
+}
+
+/*
+ * Close the server's socket file, if one is attached.  Called with the server locked.
+ */
+void
+smb_close_socket(struct smb_sb_info *server)
+{
+	struct file * file = server->sock_file;
+
+	if (file)
+	{
+		struct socket * socket = server_sock(server);
+
+		printk("smb_close_socket: closing socket %p\n", socket);
+		/* We need a way to check for tasks running the callback!
+		 * (server_sock() yields NULL when the file is not a stream socket.)
+		 */
+		if (socket && socket->sk && socket->sk->data_ready == smb_data_callback)
+			printk("smb_close_socket: still catching keepalives!\n");
+
+		server->sock_file = NULL;
+		close_fp(file);
+	}
 }
 
 static int
-smb_send_raw(struct socket *sock, unsigned char *source, int length)
+smb_send_raw(struct socket *socket, unsigned char *source, int length)
 {
 	int result;
 	int already_sent = 0;
 
 	while (already_sent < length)
 	{
-		result = _send(sock,
+		result = _send(socket,
 			       (void *) (source + already_sent),
 			       length - already_sent);
 
@@ -245,14 +284,14 @@
 }
 
 static int
-smb_receive_raw(struct socket *sock, unsigned char *target, int length)
+smb_receive_raw(struct socket *socket, unsigned char *target, int length)
 {
 	int result;
 	int already_read = 0;
 
 	while (already_read < length)
 	{
-		result = _recvfrom(sock,
+		result = _recvfrom(socket,
 				   (void *) (target + already_read),
 				   length - already_read, 0);
 
@@ -272,7 +311,7 @@
 }
 
 static int
-smb_get_length(struct socket *sock, unsigned char *header)
+smb_get_length(struct socket *socket, unsigned char *header)
 {
 	int result;
 	unsigned char peek_buf[4];
@@ -281,7 +320,7 @@
       re_recv:
 	fs = get_fs();
 	set_fs(get_ds());
-	result = smb_receive_raw(sock, peek_buf, 4);
+	result = smb_receive_raw(socket, peek_buf, 4);
 	set_fs(fs);
 
 	if (result < 0)
@@ -312,21 +351,6 @@
 	return smb_len(peek_buf);
 }
 
-static struct socket *
-server_sock(struct smb_sb_info *server)
-{
-	struct file *file;
-	struct inode *inode;
-
-	if (server == NULL)
-		return NULL;
-	if ((file = server->sock_file) == NULL)
-		return NULL;
-	if ((inode = file->f_dentry->d_inode) == NULL)
-		return NULL;
-	return &(inode->u.socket_i);
-}
-
 /*
  * smb_receive
  * fs points to the correct segment
@@ -334,12 +358,12 @@
 static int
 smb_receive(struct smb_sb_info *server)
 {
-	struct socket *sock = server_sock(server);
+	struct socket *socket = server_sock(server);
 	int len;
 	int result;
 	unsigned char peek_buf[4];
 
-	len = smb_get_length(sock, peek_buf);
+	len = smb_get_length(socket, peek_buf);
 
 	if (len < 0)
 	{
@@ -352,6 +376,7 @@
 		pr_debug("smb_receive: Increase packet size from %d to %d\n",
 			server->packet_size, len + 4);
 		smb_vfree(server->packet);
+		server->packet = 0;
 		server->packet_size = 0;
 		server->packet = smb_vmalloc(len + 4);
 		if (server->packet == NULL)
@@ -361,7 +386,7 @@
 		server->packet_size = len + 4;
 	}
 	memcpy(server->packet, peek_buf, 4);
-	result = smb_receive_raw(sock, server->packet + 4, len);
+	result = smb_receive_raw(socket, server->packet + 4, len);
 
 	if (result < 0)
 	{
@@ -371,11 +396,10 @@
 	server->rcls = *(server->packet+9);
 	server->err = WVAL(server->packet, 11);
 
-	if (server->rcls != 0)
-	{
-		pr_debug("smb_receive: rcls=%d, err=%d\n",
-			 server->rcls, server->err);
-	}
+#ifdef SMBFS_DEBUG_VERBOSE
+if (server->rcls != 0)
+printk("smb_receive: rcls=%d, err=%d\n", server->rcls, server->err);
+#endif
 	return result;
 }
 
@@ -469,6 +493,13 @@
 		total_data = WVAL(inbuf, smb_tdrcnt);
 		total_param = WVAL(inbuf, smb_tprcnt);
 
+#ifdef SMBFS_PARANOIA
+if ((data_len >= total_data || param_len >= total_param) &&
+   !(data_len >= total_data && param_len >= total_param))
+printk("smb_receive_trans2: dlen=%d, tdata=%d, plen=%d, tlen=%d\n",
+data_len, total_data, param_len, total_param);
+#endif
+		/* shouldn't this be an OR test? don't want to overrun */
 		if ((data_len >= total_data) && (param_len >= total_param))
 		{
 			break;
@@ -477,11 +508,9 @@
 		{
 			goto fail;
 		}
+		result = -EIO;
 		if (server->rcls != 0)
-		{
-			result = -EIO;
 			goto fail;
-		}
 	}
 	*ldata = data_len;
 	*lparam = param_len;
@@ -496,32 +525,33 @@
 	return result;
 }
 
+/*
+ * Called with the server locked
+ */
 int
 smb_request(struct smb_sb_info *server)
 {
 	unsigned long old_mask;
 	unsigned long fs;
 	int len, result;
+	unsigned char *buffer;
 
-	unsigned char *buffer = (server == NULL) ? NULL : server->packet;
+	result = -EBADF;
+	if (!server) /* this can't happen */
+		goto bad_no_server;
+		
+	buffer = server->packet;
+	if (!buffer)
+		goto bad_no_packet;
 
-	if (buffer == NULL)
-	{
-		pr_debug("smb_request: Bad server!\n");
-		return -EBADF;
-	}
+	result = -EIO;
 	if (server->state != CONN_VALID)
-	{
-		return -EIO;
-	}
+		goto bad_no_conn;
+
 	if ((result = smb_dont_catch_keepalive(server)) != 0)
-	{
-		server->state = CONN_INVALID;
-		smb_invalidate_inodes(server);
-		return result;
-	}
-	len = smb_len(buffer) + 4;
+		goto bad_conn;
 
+	len = smb_len(buffer) + 4;
 	pr_debug("smb_request: len = %d cmd = 0x%X\n", len, buffer[8]);
 
 	old_mask = current->blocked;
@@ -544,17 +574,31 @@
 		int result2 = smb_catch_keepalive(server);
 		if (result2 < 0)
 		{
+			printk("smb_request: catch keepalive failed\n");
 			result = result2;
 		}
 	}
 	if (result < 0)
-	{
-		server->state = CONN_INVALID;
-		smb_invalidate_inodes(server);
-	}
-	pr_debug("smb_request: result = %d\n", result);
+		goto bad_conn;
 
+out:
+	pr_debug("smb_request: result = %d\n", result);
 	return result;
+	
+bad_conn:
+	printk("smb_request: result %d, setting invalid\n", result);
+	server->state = CONN_INVALID;
+	smb_invalidate_inodes(server);
+	goto out;		
+bad_no_server:
+	printk("smb_request: no server!\n");
+	goto out;
+bad_no_packet:
+	printk("smb_request: no packet!\n");
+	goto out;
+bad_no_conn:
+	printk("smb_request: connection %d not valid!\n", server->state);
+	goto out;
 }
 
 #define ROUND_UP(x) (((x)+3) & ~3)
@@ -629,11 +673,11 @@
 	iov[3].iov_len = ldata;
 
 	err = scm_send(sock, &msg, &scm);
-        if (err < 0)
-                return err;
-	
-	err = sock->ops->sendmsg(sock, &msg, packet_length, &scm);
-	scm_destroy(&scm);
+        if (err >= 0)
+	{
+		err = sock->ops->sendmsg(sock, &msg, packet_length, &scm);
+		scm_destroy(&scm);
+	}
 	return err;
 }
 
@@ -655,16 +699,19 @@
 	pr_debug("smb_trans2_request: com=%d, ld=%d, lp=%d\n",
 		 trans2_command, ldata, lparam);
 
+	/*
+	 * These are initialized in smb_request_ok, but not here??
+	 */
+	server->rcls = 0;
+	server->err = 0;
+
+	result = -EIO;
 	if (server->state != CONN_VALID)
-	{
-		return -EIO;
-	}
+		goto out;
+
 	if ((result = smb_dont_catch_keepalive(server)) != 0)
-	{
-		server->state = CONN_INVALID;
-		smb_invalidate_inodes(server);
-		return result;
-	}
+		goto bad_conn;
+
 	old_mask = current->blocked;
 	current->blocked |= ~(_S(SIGKILL) | _S(SIGSTOP));
 	fs = get_fs();
@@ -691,11 +738,15 @@
 		}
 	}
 	if (result < 0)
-	{
-		server->state = CONN_INVALID;
-		smb_invalidate_inodes(server);
-	}
+		goto bad_conn;
 	pr_debug("smb_trans2_request: result = %d\n", result);
 
+out:
 	return result;
+
+bad_conn:
+	printk("smb_trans2_request: connection bad, setting invalid\n");
+	server->state = CONN_INVALID;
+	smb_invalidate_inodes(server);
+	goto out;
 }

FUNET's LINUX-ADM group, linux-adm@nic.funet.fi
TCL-scripts by Sam Shen, slshen@lbl.gov