Created
March 28, 2025 02:07
-
-
Save jlucaso1/f30286614a0bfc11a7d4f1f6efa400a3 to your computer and use it in GitHub Desktop.
A TCP and UDP reverse proxy in Zig
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| const std = @import("std"); | |
| const net = std.net; | |
| const mem = std.mem; | |
| const fmt = std.fmt; | |
| const process = std.process; | |
| const log = std.log.scoped(.proxy); | |
| const posix = std.posix; // For UDP socket operations | |
| const Thread = std.Thread; | |
| const io = std.io; | |
| const Allocator = mem.Allocator; | |
/// Parsed command-line arguments for the proxy.
const Args = struct {
    /// Transport protocols selectable on the command line.
    const Protocol = enum { tcp, udp };

    /// Which transport to proxy.
    protocol: Protocol,
    /// Frontend (listening side) address.
    frontend_addr: net.Address,
    /// Backend (forwarding target) address.
    backend_addr: net.Address,
};
/// Parses a "host:port" string into a `net.Address`.
///
/// The port is the decimal text after the last ':' and must fit in a `u16`.
/// The host may be an IPv4 literal, an IPv6 literal (optionally wrapped in
/// brackets, e.g. "[::1]:8080"), or a hostname. Literal IPs are parsed
/// directly; anything else is resolved with `net.getAddressList` (which may
/// perform a DNS lookup) and the first result is returned.
///
/// Errors:
/// - `error.InvalidAddressFormat` when there is no ':' or the host is empty.
/// - `error.HostNotFound` when resolution yields no addresses.
/// - Any error from `fmt.parseInt` (bad port) or `net.getAddressList`.
fn parseAddr(allocator: Allocator, input: []const u8) !net.Address {
    const colon_index = mem.lastIndexOfScalar(u8, input, ':') orelse {
        log.err("Address '{s}' must be in host:port format", .{input});
        return error.InvalidAddressFormat;
    };
    const port = try fmt.parseInt(u16, input[colon_index + 1 ..], 10);

    // Accept the conventional "[v6-literal]:port" spelling by looking only at
    // the text between the brackets.
    var host_part = input[0..colon_index];
    if (host_part.len >= 2 and host_part[0] == '[' and host_part[host_part.len - 1] == ']') {
        host_part = host_part[1 .. host_part.len - 1];
    }
    if (host_part.len == 0) {
        log.err("Address '{s}' must be in host:port format", .{input});
        return error.InvalidAddressFormat;
    }

    // Literal IPs need neither the resolver nor an allocation.
    if (net.Address.parseIp(host_part, port)) |literal| {
        return literal;
    } else |_| {
        // Not a literal IP; fall through to name resolution below.
    }

    const addr_list = try net.getAddressList(allocator, host_part, port);
    defer addr_list.deinit();
    if (addr_list.addrs.len == 0) {
        log.err("Could not resolve host '{s}'", .{host_part});
        return error.HostNotFound;
    }
    // `net.Address` is a plain value, so returning it after `deinit` is safe.
    return addr_list.addrs[0];
}
// --- TCP Proxy Logic ---

/// Describes one direction of a proxied TCP connection for a copy task.
const CopyContext = struct {
    /// Label used in log messages (e.g., "client->backend").
    name: []const u8,
    /// Stream side that bytes are read from.
    reader: net.Stream.Reader,
    /// Stream side that the bytes are written to.
    writer: net.Stream.Writer,
};
/// Copies bytes from `context.reader` to `context.writer` until the reader
/// reports end-of-stream (a zero-byte read) or either side fails.
///
/// Intended as a thread entry point; failures are logged, not returned.
/// On every exit path the send side of the destination socket is shut down so
/// the peer observes EOF. Without that, a half-closed connection leaves the
/// opposite-direction copy blocked in `read` until the peer gives up by itself.
fn copyStreamFn(context: *CopyContext) void {
    // `writer.context` is the underlying `net.Stream`. Shutdown can fail when
    // the socket is already dead; that is harmless, so it is only logged.
    defer posix.shutdown(context.writer.context.handle, .send) catch |err| {
        log.debug("Copy task '{s}' could not shut down destination: {any}", .{ context.name, err });
    };

    var buffer: [4096]u8 = undefined;
    log.debug("Starting copy task: {s}", .{context.name});
    while (true) {
        const bytes_read = context.reader.read(&buffer) catch |err| {
            log.err("Copy task '{s}' failed reading: {any}", .{ context.name, err });
            return;
        };
        if (bytes_read == 0) {
            // With blocking sockets a zero-byte read means the peer closed its side.
            log.debug("Copy task '{s}' read 0 bytes, assuming EOF.", .{context.name});
            return;
        }
        context.writer.writeAll(buffer[0..bytes_read]) catch |err| {
            log.err("Copy task '{s}' failed writing: {any}", .{ context.name, err });
            return;
        };
    }
}
/// Serves one accepted TCP client: connects to `backend_addr`, relays bytes in
/// both directions until both copy tasks finish, then closes both sockets.
///
/// Intended as a thread entry point; all failures are logged and end the
/// session. The client->backend direction runs on a helper thread while the
/// backend->client direction runs on the calling thread, so no thread sits
/// idle in `join` and there is no second spawn that could fail half-way.
fn handleTcpClient(client_conn: net.Server.Connection, backend_addr: net.Address) void {
    // Runs last (defers are LIFO), after the backend stream is closed.
    defer client_conn.stream.close();
    log.info("Accepted connection from {}", .{client_conn.address});

    const backend_stream = net.tcpConnectToAddress(backend_addr) catch |err| {
        log.err("Failed to connect to backend {}: {any}", .{ backend_addr, err });
        return;
    };
    defer backend_stream.close();
    log.info("Connected to backend {} for client {}", .{ backend_addr, client_conn.address });

    // Both contexts live on this stack frame; the join below guarantees the
    // helper thread is done with its context before the frame is released.
    var ctx_client_to_backend = CopyContext{
        .reader = client_conn.stream.reader(),
        .writer = backend_stream.writer(),
        .name = "client->backend",
    };
    var ctx_backend_to_client = CopyContext{
        .reader = backend_stream.reader(),
        .writer = client_conn.stream.writer(),
        .name = "backend->client",
    };

    const thread_c2b = Thread.spawn(.{}, copyStreamFn, .{&ctx_client_to_backend}) catch |err| {
        log.err("Failed to spawn client->backend copy thread: {any}", .{err});
        return;
    };

    // Relay the other direction right here, then wait for the helper.
    copyStreamFn(&ctx_backend_to_client);
    thread_c2b.join();

    log.info("Closing connection for client {}", .{client_conn.address});
}
/// Runs the TCP proxy accept loop: listens on `frontend_addr` and hands every
/// accepted connection to `handleTcpClient` on its own detached thread.
///
/// Returns only on failure: an error from `listen`, or an `accept` error that
/// is not treated as transient. A failure to start a handler thread drops that
/// one connection and keeps the server running.
fn runTcpProxy(frontend_addr: net.Address, backend_addr: net.Address) !void {
    var server = try frontend_addr.listen(.{ .reuse_address = true });
    defer server.deinit();
    log.info("TCP Proxy listening on {}", .{server.listen_address});
    log.info("Forwarding connections to {}", .{backend_addr});

    while (true) {
        const connection = server.accept() catch |err| {
            log.err("Failed to accept connection: {any}", .{err});
            // Errors tied to a single connection attempt are retried; anything else is fatal.
            if (err == error.NetworkUnreachable or err == error.ConnectionAborted) {
                continue;
            } else {
                return err;
            }
        };

        const handler = Thread.spawn(.{}, handleTcpClient, .{ connection, backend_addr }) catch |err| {
            log.err("Failed to spawn handler thread for {}: {any}", .{ connection.address, err });
            // The handler never ran, so the accepted socket must be closed here.
            connection.stream.close();
            continue;
        };
        // Nobody joins handler threads; detaching lets their resources be
        // reclaimed when they exit instead of accumulating for the process lifetime.
        handler.detach();
    }
}
// --- UDP Proxy Logic ---

/// Receive buffer size for a single datagram, in bytes.
const MAX_UDP_PAYLOAD = 65507;

/// Runs a minimal single-socket UDP relay bound to `frontend_addr`.
///
/// A datagram whose source equals `backend_addr` is forwarded to the most
/// recently seen client; any other datagram is forwarded to `backend_addr` and
/// its sender becomes the "last client".
///
/// WARNING: only one client address is remembered, so with several concurrent
/// clients a reply goes to whichever client sent last. Empty datagrams are
/// ignored.
///
/// Send failures are logged and skipped. Socket setup failures and `recvfrom`
/// errors other than `error.WouldBlock` are returned.
fn runUdpProxy(frontend_addr: net.Address, backend_addr: net.Address) !void {
    const proxy_sock_fd = try posix.socket(frontend_addr.any.family, posix.SOCK.DGRAM, 0);
    // A single `defer` covers success and error exits alike. Pairing it with an
    // `errdefer` would close the descriptor twice on the error path.
    defer posix.close(proxy_sock_fd);

    try posix.setsockopt(proxy_sock_fd, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try posix.bind(proxy_sock_fd, &frontend_addr.any, frontend_addr.getOsSockLen());
    log.info("UDP Proxy listening on {}", .{frontend_addr});
    log.info("Forwarding packets to {}", .{backend_addr});

    var buffer: [MAX_UDP_PAYLOAD]u8 = undefined;
    var last_client_addr: ?net.Address = null; // Most recent non-backend sender.
    var src_addr_storage: net.Address = undefined; // Filled in by recvfrom.

    while (true) {
        // recvfrom overwrites the length with the size it actually stored, so
        // it is reset to the full capacity before every call. The capacity is
        // the whole `net.Address` union so IPv6 senders are not truncated.
        var src_addr_len: posix.socklen_t = @sizeOf(net.Address);
        const received_bytes = posix.recvfrom(
            proxy_sock_fd,
            &buffer,
            0, // flags
            @ptrCast(&src_addr_storage),
            &src_addr_len,
        ) catch |err| {
            log.err("UDP recvfrom failed: {any}", .{err});
            if (err == error.WouldBlock) continue; // Only expected on non-blocking sockets.
            return err; // Everything else is treated as fatal.
        };
        if (received_bytes == 0) {
            continue; // Ignore empty datagrams.
        }

        // Normalise the raw sockaddr into a net.Address for comparison and logging.
        const src_addr = net.Address.initPosix(@alignCast(@ptrCast(&src_addr_storage)));
        const payload = buffer[0..received_bytes];

        if (net.Address.eql(src_addr, backend_addr)) {
            // Reply from the backend: relay it to the last known client, if any.
            if (last_client_addr) |client_addr| {
                log.debug("UDP Backend -> Client ({}) : {} bytes", .{ client_addr, received_bytes });
                _ = posix.sendto(
                    proxy_sock_fd,
                    payload,
                    0, // flags
                    &client_addr.any,
                    client_addr.getOsSockLen(),
                ) catch |err| {
                    // Keep serving; the client may simply be gone.
                    log.warn("UDP sendto (to client {}) failed: {any}", .{ client_addr, err });
                };
            } else {
                log.warn("UDP Received packet from backend, but no client address known. Dropping.", .{});
            }
        } else {
            // Datagram from a client: remember the sender and relay to the backend.
            log.debug("UDP Client ({}) -> Backend : {} bytes", .{ src_addr, received_bytes });
            last_client_addr = src_addr;
            _ = posix.sendto(
                proxy_sock_fd,
                payload,
                0, // flags
                &backend_addr.any,
                backend_addr.getOsSockLen(),
            ) catch |err| {
                // Keep serving; the backend may be temporarily down.
                log.warn("UDP sendto (to backend {}) failed: {any}", .{ backend_addr, err });
            };
        }
    }
}
// --- Main Function ---

/// Prints the command-line synopsis and two example invocations through
/// `std.debug.print`, substituting `exe_name` for the program name.
fn usage(exe_name: []const u8) void {
    const help_template =
        \\Usage: {s} <tcp|udp> <frontend_host:port> <backend_host:port>
        \\
        \\Examples:
        \\ {s} tcp 0.0.0.0:8080 192.168.1.100:80
        \\ {s} udp 127.0.0.1:5353 8.8.8.8:53
        \\
    ;
    const substitutions = .{ exe_name, exe_name, exe_name };
    std.debug.print(help_template, substitutions);
}
/// Program entry point: parses `<tcp|udp> <frontend_host:port> <backend_host:port>`
/// from the command line, resolves both addresses and runs the selected proxy.
///
/// Returns `error.MissingArgument` when fewer than three arguments are given,
/// `error.InvalidProtocol` for an unknown protocol, and otherwise whatever
/// `parseAddr` or the proxy loop returns. Extra arguments are ignored with a
/// warning.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var args_iter = try process.argsWithAllocator(allocator);
    defer args_iter.deinit();

    // argv[0] can be absent when the program is exec'd with an empty argument
    // vector, so fall back to a placeholder name instead of unwrapping.
    const exe_name = args_iter.next() orelse "proxy";
    const protocol_str = args_iter.next() orelse {
        usage(exe_name);
        return error.MissingArgument;
    };
    const frontend_str = args_iter.next() orelse {
        usage(exe_name);
        return error.MissingArgument;
    };
    const backend_str = args_iter.next() orelse {
        usage(exe_name);
        return error.MissingArgument;
    };
    if (args_iter.next() != null) {
        log.warn("Ignoring extra arguments", .{});
    }

    // The protocol is validated before any address is resolved.
    const protocol: Args.Protocol = if (mem.eql(u8, protocol_str, "tcp"))
        .tcp
    else if (mem.eql(u8, protocol_str, "udp"))
        .udp
    else {
        log.err("Invalid protocol '{s}'. Must be 'tcp' or 'udp'.", .{protocol_str});
        usage(exe_name);
        return error.InvalidProtocol;
    };

    // Built in one expression so no field is ever left undefined. Fields are
    // evaluated in order: frontend first, then backend.
    const args: Args = .{
        .protocol = protocol,
        .frontend_addr = parseAddr(allocator, frontend_str) catch |err| {
            log.err("Invalid frontend address '{s}': {any}", .{ frontend_str, err });
            return err;
        },
        .backend_addr = parseAddr(allocator, backend_str) catch |err| {
            log.err("Invalid backend address '{s}': {any}", .{ backend_str, err });
            return err;
        },
    };

    switch (args.protocol) {
        .tcp => try runTcpProxy(args.frontend_addr, args.backend_addr),
        .udp => try runUdpProxy(args.frontend_addr, args.backend_addr),
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment