diff options
| -rw-r--r-- | include/swaylock/swaylock.h | 3 | ||||
| -rw-r--r-- | swaylock/main.c | 13 | ||||
| -rw-r--r-- | swaylock/password.c | 44 | 
3 files changed, 34 insertions, 26 deletions
| diff --git a/include/swaylock/swaylock.h b/include/swaylock/swaylock.h index 173e8b12..ed9fea19 100644 --- a/include/swaylock/swaylock.h +++ b/include/swaylock/swaylock.h @@ -24,9 +24,8 @@ struct swaylock_args {  };  struct swaylock_password { -	size_t size;  	size_t len; -	char *buffer; +	char buffer[1024];  };  struct swaylock_state { diff --git a/swaylock/main.c b/swaylock/main.c index 4c6b44c6..200c1b5f 100644 --- a/swaylock/main.c +++ b/swaylock/main.c @@ -8,6 +8,7 @@  #include <stdio.h>  #include <stdlib.h>  #include <string.h> +#include <sys/mman.h>  #include <sys/stat.h>  #include <time.h>  #include <unistd.h> @@ -18,10 +19,15 @@  #include "background-image.h"  #include "pool-buffer.h"  #include "cairo.h" +#include "log.h"  #include "util.h"  #include "wlr-input-inhibitor-unstable-v1-client-protocol.h"  #include "wlr-layer-shell-unstable-v1-client-protocol.h" +void sway_terminate(int exit_code) { +	exit(exit_code); +} +  static void daemonize() {  	int fds[2];  	if (pipe(fds) != 0) { @@ -236,6 +242,13 @@ int main(int argc, char **argv) {  		}  	} +#ifdef __linux__ +	// Most non-linux platforms require root to mlock() +	if (mlock(state.password.buffer, sizeof(state.password.buffer)) != 0) { +		sway_abort("Unable to mlock() password memory."); +	} +#endif +  	wl_list_init(&state.surfaces);  	state.xkb.context = xkb_context_new(XKB_CONTEXT_NO_FLAGS);  	state.display = wl_display_connect(NULL); diff --git a/swaylock/password.c b/swaylock/password.c index 1839f991..c8df3de8 100644 --- a/swaylock/password.c +++ b/swaylock/password.c @@ -1,7 +1,9 @@ +#define _XOPEN_SOURCE 500  #include <assert.h>  #include <pwd.h>  #include <security/pam_appl.h>  #include <stdlib.h> +#include <string.h>  #include <unistd.h>  #include <wlr/util/log.h>  #include <xkbcommon/xkbcommon.h> @@ -20,7 +22,7 @@ static int function_conversation(int num_msg, const struct pam_message **msg,  		switch (msg[i]->msg_style) {  		case PAM_PROMPT_ECHO_OFF:  		case PAM_PROMPT_ECHO_ON: -			pam_reply[i].resp = pw->buffer; +			pam_reply[i].resp = strdup(pw->buffer); // PAM clears and frees this  			break;  		case PAM_ERROR_MSG:  		case PAM_TEXT_INFO: @@ -30,6 +32,16 @@ static int function_conversation(int num_msg, const struct pam_message **msg,  	return PAM_SUCCESS;  } +void clear_password_buffer(struct swaylock_password *pw) { +	// Use volatile keyword so so compiler can't optimize this out. +	volatile char *buffer = pw->buffer; +	volatile char zero = '\0'; +	for (size_t i = 0; i < sizeof(buffer); ++i) { +		buffer[i] = zero; +	} +	pw->len = 0; +} +  static bool attempt_password(struct swaylock_password *pw) {  	struct passwd *passwd = getpwuid(getuid());  	char *username = passwd->pw_name; @@ -38,6 +50,7 @@ static bool attempt_password(struct swaylock_password *pw) {  	};  	pam_handle_t *local_auth_handle = NULL;  	int pam_err; +	// TODO: only call pam_start once. keep the same handle the whole time  	if ((pam_err = pam_start("swaylock", username,  					&local_conversation, &local_auth_handle)) != PAM_SUCCESS) {  		wlr_log(L_ERROR, "PAM returned error %d", pam_err); @@ -46,18 +59,15 @@ static bool attempt_password(struct swaylock_password *pw) {  		wlr_log(L_ERROR, "pam_authenticate failed");  		goto fail;  	} +	// TODO: only call pam_end once we succeed at authing. refresh tokens beforehand  	if ((pam_err = pam_end(local_auth_handle, pam_err)) != PAM_SUCCESS) {  		wlr_log(L_ERROR, "pam_end failed");  		goto fail;  	} -	// PAM frees this -	pw->buffer = NULL; -	pw->len = pw->size = 0; +	clear_password_buffer(pw);  	return true;  fail: -	// PAM frees this -	pw->buffer = NULL; -	pw->len = pw->size = 0; +	clear_password_buffer(pw);  	return false;  } @@ -70,24 +80,10 @@ static bool backspace(struct swaylock_password *pw) {  }  static void append_ch(struct swaylock_password *pw, uint32_t codepoint) { -	if (!pw->buffer) { -		pw->size = 8; -		if (!(pw->buffer = malloc(pw->size))) { -			// TODO: Display error -			return; -		} -		pw->buffer[0] = 0; -	}  	size_t utf8_size = utf8_chsize(codepoint); -	if (pw->len + utf8_size + 1 >= pw->size) { -		size_t size = pw->size * 2; -		char *buffer = realloc(pw->buffer, size); -		if (!buffer) { -			// TODO: Display error -			return; -		} -		pw->size = size; -		pw->buffer = buffer; +	if (pw->len + utf8_size + 1 >= sizeof(pw->buffer)) { +		// TODO: Display error +		return;  	}  	utf8_encode(&pw->buffer[pw->len], codepoint);  	pw->buffer[pw->len + utf8_size] = 0; | 
