diff options
| author | ame <[email protected]> | 2026-02-15 04:08:16 -0600 |
|---|---|---|
| committer | ame <[email protected]> | 2026-02-15 04:08:16 -0600 |
| commit | db2611fcad18f73572dd1b344e4197536086be53 (patch) | |
| tree | 8d6df833110e57fa7d77753571acfda2ebb23f95 /src/net.c | |
| parent | 0a909a9dc5879e592d92c6eedeb59da8cf503392 (diff) | |
ssl server support, websocket upgrades, and net changes
Diffstat (limited to 'src/net.c')
| -rw-r--r-- | src/net.c | 223 |
1 files changed, 112 insertions, 111 deletions
@@ -3,6 +3,7 @@ #include "net/lua.h"
#include "net/luai.h"
#include "types/str.h"
+#include "net/websocket.h"
#include <fcntl.h>
@@ -28,6 +29,7 @@ void ssl_init(){ if(has_ssl_init == 0){
has_ssl_init = 1;
SSL_library_init();
+ OpenSSL_add_all_algorithms();
SSL_load_error_strings();
}
}
@@ -171,96 +173,29 @@ void free_url(struct url u){ if(u.port != NULL) str_free(u.port);
}
-#define BUFFER_LEN 4096
-struct wss_data {
- SSL* ssl;
- SSL_CTX* ctx;
- int sock;
- str* buffer;
-};
-
-struct ws_frame_info {
- int fin;
- int rsv1;
- int rsv2;
- int rsv3;
- int opcode;
- int mask;
- int payload;
-};
-
int i_ws_read(lua_State* L){
lua_pushstring(L, "_");
lua_gettable(L, 1);
- struct wss_data* data = lua_touserdata(L, -1);
- char buffer[BUFFER_LEN] = {0};
- int total_len = 0;
- int len;
-
- for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){
- total_len += len;
- if(total_len >= 2) break;
- }
-
- if(len < 0) luaI_error(L, len, "SSL_read error");
-
- uint64_t payload = 0;
- struct ws_frame_info frame_info = {.fin = (buffer[0] >> 7) & 1, .rsv1 = (buffer[0] >> 6) & 1,
- .rsv2 = (buffer[0] >> 5) & 1, .rsv3 = (buffer[0] >> 4) & 1, .opcode = buffer[0] & 0b1111,
- .mask = (buffer[1] >> 7) & 1, .payload = buffer[1] & 0b1111111};
- //printf("fin: %i\npayload: %i\n", frame_info.fin, frame_info.payload);
- memset(buffer, 0, total_len);
- total_len = 0;
-
- if(frame_info.payload <= 125) payload = frame_info.payload;
- else if(frame_info.payload == 126) {
- for(; (len = SSL_read(data->ssl, buffer + total_len, 2 - total_len)) > 0;){
- total_len += len;
- if(total_len >= 2) break;
- }
-
- if(len < 0) luaI_error(L, len, "SSL_read error");
-
- payload = (buffer[0] & 0xff) << 8 | (buffer[1] & 0xff);
- } else {
- for(; (len = SSL_read(data->ssl, buffer + total_len, 8 - total_len)) > 0;){
- total_len += len;
- if(total_len >= 8) break;
- }
-
- if(len < 0) luaI_error(L, -1, "SSL_read error");
-
-
- payload = ((uint64_t)buffer[0] & 0xff) << 56 | ((uint64_t)buffer[1] & 0xff) << 48 | ((uint64_t)buffer[2] & 0xff) << 40 |
- ((uint64_t)buffer[3] & 0xff) << 32 | (buffer[4] & 0xff) << 24 | (buffer[5] & 0xff) << 16 | (buffer[6] & 0xff) << 8 | (buffer[7] & 0xff);
- }
- //printf("final payload: %lu\n", payload);
-
- str* message = str_init("");
- memset(buffer, 0, BUFFER_LEN);
- for(; message->len != payload && (len = SSL_read(data->ssl, buffer, lesser(payload - message->len, BUFFER_LEN))) > 0;){
- str_pushl(message, buffer, len);
- memset(buffer, 0, len);
- }
-
- if(len < 0) {
- str_free(message);
- luaI_error(L, len, "SSL_read error");
+ struct net_data* data = lua_touserdata(L, -1);
+ struct ws_frame_info frame = {};
+ if(ws_read(data, &frame) == -1){
+ luaI_error(L, -1, "SSL_read error");
+ str_clear(data->buffer);
}
lua_newtable(L);
int idx = lua_gettop(L);
- luaI_tsetsl(L, idx, "content", message->c, message->len);
- luaI_tseti(L, idx, "opcode", frame_info.opcode);
+ luaI_tsetsl(L, idx, "content", data->buffer->c, data->buffer->len);
+ luaI_tseti(L, idx, "opcode", frame.opcode);
- str_free(message);
+ str_clear(data->buffer);
return 1;
}
int i_ws_write(lua_State* L){
lua_pushstring(L, "_");
lua_gettable(L, 1);
- struct wss_data* data = lua_touserdata(L, -1);
+ struct net_data* data = lua_touserdata(L, -1);
uint64_t clen;
const char* content = luaL_tolstring(L, 2, &clen);
@@ -289,7 +224,7 @@ int i_ws_write(lua_State* L){ int i_ws_close(lua_State* L){
lua_pushstring(L, "_");
lua_gettable(L, 1);
- struct wss_data* data = lua_touserdata(L, -1);
+ struct net_data* data = lua_touserdata(L, -1);
if(data != NULL && data->buffer != NULL){
str_free(data->buffer);
@@ -310,7 +245,7 @@ int i_ws_close(lua_State* L){ return 0;
}
-
+#define BUFFER_LEN 4096
int l_wss(lua_State* L){
uint64_t len = 0;
@@ -331,8 +266,8 @@ int l_wss(lua_State* L){ SSL* ssl = ssl_connect(ctx, sock, awa.domain->c);
char* request = calloc(512, sizeof * request);
- sprintf(request, "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n"\
- "Sec-Websocket-Key: aWxvdmVsb3ZlbG92ZXlvdQ==\r\nSec-Websocket-Version: 13\r\n\r\n", path, awa.domain->c);
+ sprintf(request, "GET %s HTTP/1.1\r\nhost: %s\r\nconnection: upgrade\r\nupgrade: websocket\r\n"\
+ "sec-websocket-key: aWxvdmVsb3ZlbG92ZXlvdQ==\r\nsec-websocket-version: 13\r\n\r\n", path, awa.domain->c);
int s = SSL_write(ssl, request, strlen(request));
free_url(awa);
@@ -360,7 +295,7 @@ int l_wss(lua_State* L){ luaI_error(L, -1, "error with header formating");
}
- struct wss_data *data = malloc(sizeof * data);
+ struct net_data *data = malloc(sizeof * data);
data->ssl = ssl;
data->ctx = ctx;
data->sock = sock;
@@ -537,7 +472,7 @@ int _request(lua_State* L, struct request_state* state){ lua_newtable(L);
int header_idx = lua_gettop(L);
- luaI_tsets(L, header_idx, "User-Agent", "lullaby/"MAJOR_VERSION);
+ luaI_tsets(L, header_idx, "user-agent", "lullaby/"MAJOR_VERSION);
if(params >= 3){
lua_pushvalue(L, header_idx);
@@ -564,7 +499,7 @@ int _request(lua_State* L, struct request_state* state){ //char* req = "GET / HTTP/1.1\nHost: amyy.cc\nConnection: Close\n\n";
char* request = calloc(cont_len + header->len + strlen(host) + strlen(path) + 512, sizeof * request);
- sprintf(request, "%s %s HTTP/1.1\r\nHost: %s\r\nConnection: Close%s\r\n\r\n%s", action, path, host, header->c, cont);
+ sprintf(request, "%s %s HTTP/1.1\r\nhost: %s\r\nconnection: close%s\r\n\r\n%s", action, path, host, header->c, cont);
if(awa.path != NULL) str_free(awa.path);
if(awa.domain != NULL) str_free(awa.domain);
@@ -610,16 +545,16 @@ int _request(lua_State* L, struct request_state* state){ luaI_tsets(L, idx, (owo->P[i].key)->c, ((str*)owo->P[i].value)->c);
}
//done out of pure laziness, parse_header was meant for requests but works fine for responses, change this later
- luaI_treplk(L, idx, "Path", "code");
- luaI_treplk(L, idx, "Request", "version");
- luaI_treplk(L, idx, "Version", "code-name");
+ luaI_treplk(L, idx, "path", "code");
+ luaI_treplk(L, idx, "request", "version");
+ luaI_treplk(L, idx, "version", "code-name");
lua_pushstring(L, "code");
lua_gettable(L, idx);
int code = atoi(lua_tostring(L, -1));
luaI_tseti(L, idx, "code", code);
- void* encoding = parray_get(owo, "Transfer-Encoding");
+ void* encoding = parray_get(owo, "transfer-encoding");
//struct _srequest_state *read_state = calloc(sizeof * read_state, 1);
//read_state->ctx = state->ctx;
@@ -710,8 +645,9 @@ void path_parse(struct net_path_t* path, str* raw){ _Atomic size_t threads = 0;
void* handle_client(void *_arg){
+ signal(SIGPIPE, SIG_IGN);
thread_arg_struct* args = (thread_arg_struct*)_arg;
- int client_fd = args->fd;
+ struct net_data* ctx = args->ctx;
char* buffer;
int header_eof = -1;
lua_State* L = args->L;
@@ -719,7 +655,7 @@ void* handle_client(void *_arg){ char* header = NULL;
- int64_t bite = recv_header(client_fd, &buffer, &header);
+ int64_t bite = recv_header(ctx, &buffer, &header);
header_eof = header - buffer;
if(bite > 0){
@@ -728,13 +664,13 @@ void* handle_client(void *_arg){ //checks for a valid header
int val = parse_header(buffer, header_eof + 2, &table);
- if(val == -2) net_error(client_fd, 414);
+ if(val == -2) net_error(ctx, 414);
if(val >= 0){
- str* path = (str*)parray_get(table, "Path");
- str* sR = (str*)parray_get(table, "Request");
- str* sT = (str*)parray_get(table, "Content-Type");
- str* sC = (str*)parray_get(table, "Cookie");
+ str* path = (str*)parray_get(table, "path");
+ str* sR = (str*)parray_get(table, "request");
+ str* sT = (str*)parray_get(table, "content-type");
+ str* sC = (str*)parray_get(table, "cookie");
struct file_parse* file_cont = calloc(1, sizeof * file_cont);
lua_newtable(L);
@@ -755,7 +691,7 @@ void* handle_client(void *_arg){ parray_t* v = NULL;
if(decoded_err == 1 || args->paths == NULL){
- net_error(client_fd, 400);
+ net_error(ctx, 400);
} else {
str_push(aa, decoded_path->c);
@@ -787,7 +723,7 @@ void* handle_client(void *_arg){ luaI_tsetv(L, req_idx, "cookies", lcookie);
parray_clear(cookie, STR);
- parray_remove(table, "Cookie", STR);
+ parray_remove(table, "cookie", STR);
}
if(parsed_path.query->len != 0){
@@ -805,7 +741,7 @@ void* handle_client(void *_arg){ int ld = lua_gettop(L);
luaI_tsetv(L, req_idx, "_data", ld);
luaI_tsetv(L, req_idx, "files", files_idx);
- luaI_tsetv(L, req_idx, "Body", body_idx);
+ luaI_tsetv(L, req_idx, "body", body_idx);
for(int i = 0; i != table->len; i+=1){
//printf("'%s' :: '%s'\n",table[i]->c, table[i+1]->c);
@@ -817,29 +753,34 @@ void* handle_client(void *_arg){ luaI_tsets(L, req_idx, "rawpath", path->c);
if(bite == -1){
- client_fd = -2;
+ args->ctx->sock = -2;
}
luaI_tseti(L, req_idx, "_bytes", bite - header_eof - 4);
- luaI_tseti(L, req_idx, "client_fd", client_fd);
+ //luaI_tseti(L, req_idx, "client_fd", client_fd);
luaI_tsetcf(L, req_idx, "roll", l_roll);
+ luaI_tsetv(L, res_idx, "req", req_idx);
+ luaI_tsetv(L, req_idx, "res", res_idx);
+
//functions
luaI_tsetcf(L, res_idx, "send", l_send);
luaI_tsetcf(L, res_idx, "sendfile", l_sendfile);
luaI_tsetcf(L, res_idx, "write", l_write);
luaI_tsetcf(L, res_idx, "close", l_close);
luaI_tsetcf(L, res_idx, "stop", l_stop);
+ luaI_tsetcf(L, res_idx, "upgrade", l_connection_upgrade);
//values
- luaI_tseti(L, res_idx, "client_fd", client_fd);
+ //luaI_tseti(L, res_idx, "client_fd", client_fd);
+ luaI_tsetlud(L, res_idx, "_", ctx);
luaI_tsets(L, res_idx, "_request", sR->c);
//header table
lua_newtable(L);
int header_idx = lua_gettop(L);
- luaI_tseti(L, header_idx, "Code", 200);
+ luaI_tseti(L, header_idx, "code", 200);
luaI_tsetv(L, res_idx, "header", header_idx);
@@ -885,7 +826,7 @@ void* handle_client(void *_arg){ if(lua_pcall(L, 2, 0, errtraceback_idx) != 0){
fprintf(stderr, "(net thread) %s\n", lua_tostring(L, -1));
//send an error message if send has not been called
- if(client_fd >= 0) net_error(client_fd, 500);
+ if(args->ctx->sock >= 0) net_error(ctx, 500);
goto net_end;
}
@@ -904,9 +845,9 @@ net_end: larray_clear(params);
parray_lclear(owo); //dont free the rest
- lua_pushstring(L, "client_fd");
- lua_gettable(L, res_idx);
- client_fd = luaL_checkinteger(L, -1);
+ //lua_pushstring(L, "client_fd");
+ //lua_gettable(L, res_idx);
+ //client_fd = luaL_checkinteger(L, -1);
}
@@ -922,9 +863,12 @@ net_end: parray_clear(table, STR);
}
- if(client_fd != -1){
- shutdown(client_fd, 2);
- closesocket(client_fd);
+ if(args->ctx->sock != -1){
+ if(args->ctx->ssl != NULL) SSL_shutdown(args->ctx->ssl);
+ else {
+ shutdown(args->ctx->sock, 2);
+ closesocket(args->ctx->sock);
+ }
}
free(args);
@@ -941,12 +885,23 @@ int clean_lullaby_net(lua_State* L){ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* state){
parse_mimetypes();
+ if(state->ssl) ssl_init();
//need these on windows for sockets (stupid)
#ifdef _WIN32
WSADATA Data;
WSAStartup(MAKEWORD(2, 2), &Data);
#endif
+ SSL_CTX* server_ctx;
+ if(state->ssl){
+ if(!(server_ctx = SSL_CTX_new(TLS_server_method())))
+ luaI_error(L, -7, "SSL_CTX_new error");
+ luaI_assert(L, SSL_CTX_use_certificate_file(server_ctx, state->ssl_crt->c, SSL_FILETYPE_PEM) > 0)
+ luaI_assert(L, SSL_CTX_use_PrivateKey_file(server_ctx, state->ssl_key->c, SSL_FILETYPE_PEM) > 0)
+ if (SSL_CTX_check_private_key(server_ctx) == -1)
+ luaI_error(L, -8, "key does not match crt");
+ }
+
int server_fd;
struct sockaddr_in server_addr;
struct pollfd fds[2];
@@ -1008,7 +963,8 @@ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* if(threads >= max_con){
//deny request
- net_error(*client_fd, 503);
+#warning "need a better way to do this with ssl support"
+ //net_error(*client_fd, 503);
close(*client_fd);
free(client_fd);
@@ -1016,8 +972,31 @@ int start_serv(lua_State* L, int port, parray_t* paths, struct net_server_state* }
thread_arg_struct* args = malloc(sizeof * args);
-
- args->fd = *client_fd;
+ args->ctx = malloc(sizeof * args->ctx);
+
+ args->ctx->ssl = NULL;
+ if(state->ssl){
+ args->ctx->ssl = SSL_new(server_ctx);
+ if(!args->ctx->ssl){
+ fprintf(stderr, "SSL_new fail\n");
+ close(*client_fd);
+ free(client_fd);
+ free(args);
+ continue;
+ }
+ SSL_set_fd(args->ctx->ssl, *client_fd);
+
+ int f;
+ if((f = SSL_accept(args->ctx->ssl)) <= 0){
+ fprintf(stderr, "SSL_accept fail %i\n", f);
+ close(*client_fd);
+ free(client_fd);
+ SSL_free(args->ctx->ssl);
+ free(args);
+ continue;
+ }
+ }
+ args->ctx->sock = *client_fd;
args->port = port;
args->cli = client_addr;
args->L = luaL_newstate();
@@ -1050,6 +1029,7 @@ net_end: close(server_fd);
close(efd);
free(state);
+ SSL_CTX_free(server_ctx);
for(int i = 0; i != paths->len; i++){
struct sarray_t* path = paths->P[i].value;
@@ -1157,6 +1137,9 @@ int l_listen(lua_State* L){ struct net_server_state *state = malloc(sizeof * state);
state->event_fd = -1;
+ state->ssl = 0;
+ state->ssl_crt = str_init("");
+ state->ssl_key = str_init("");
int port = luaL_checkinteger(L, 2);
@@ -1173,6 +1156,8 @@ int l_listen(lua_State* L){ luaI_tsetcf(L, mt, "PATCH", l_PATCHq);
luaI_tsetcf(L, mt, "all", l_allq);
+ luaI_tsettab(L, mt, "ssl");
+
luaI_tsetcf(L, mt, "close", l_net_close);
luaI_tsetv(L, mt, "port", 2);
@@ -1187,6 +1172,22 @@ int l_listen(lua_State* L){ if(state->event_fd == -2) luaI_error(L, -2, "closed");
+ lua_getfield(L, mt, "ssl");
+ int ssl = lua_gettop(L);
+ if(!lua_isnil(L, -1) && lua_type(L, ssl) == LUA_TTABLE){
+ lua_getfield(L, ssl, "key");
+ if(!lua_isnil(L, -1)){
+ str_push(state->ssl_key, luaL_checkstring(L, -1));
+ lua_getfield(L, ssl, "crt");
+ if(!lua_isnil(L, -1)){
+ str_push(state->ssl_crt, luaL_checkstring(L, -1));
+ lua_pop(L, 2);
+ state->ssl = 1;
+ }
+ }
+ }
+ lua_pop(L, 1);
+
return start_serv(L, port, paths, state);
;
}
|
