Skip to content

Instantly share code, notes, and snippets.

@jlucaso1
Created March 28, 2025 02:07
Show Gist options
  • Select an option

  • Save jlucaso1/f30286614a0bfc11a7d4f1f6efa400a3 to your computer and use it in GitHub Desktop.

Select an option

Save jlucaso1/f30286614a0bfc11a7d4f1f6efa400a3 to your computer and use it in GitHub Desktop.
A TCP and UDP reverse proxy in Zig
const std = @import("std");
const net = std.net;
const mem = std.mem;
const fmt = std.fmt;
const process = std.process;
const log = std.log.scoped(.proxy);
const posix = std.posix; // For UDP socket operations
const Thread = std.Thread;
const io = std.io;
const Allocator = mem.Allocator;
/// Command-line configuration after parsing: which transport to proxy and
/// the two endpoints to connect.
const Args = struct {
    /// The transports this proxy can forward.
    const Protocol = enum { tcp, udp };

    /// Transport selected on the command line.
    protocol: Protocol,
    /// Address the proxy binds / listens on.
    frontend_addr: net.Address,
    /// Address all traffic is forwarded to.
    backend_addr: net.Address,
};
/// Parse a "host:port" string into a `net.Address`.
///
/// The port is taken from after the *last* ':' so bare IPv6 literals such as
/// "::1:53" work; the bracketed form "[::1]:53" is also accepted (the
/// brackets are stripped before resolution). The host part is passed to
/// `net.getAddressList`, which may perform a DNS lookup for hostnames; the
/// first resolved address is returned.
///
/// `allocator` is used only for the temporary address list, which is freed
/// before returning.
///
/// Errors: `error.InvalidAddressFormat` when there is no ':' separator,
/// `error.HostNotFound` when resolution yields no addresses, plus any error
/// from `fmt.parseInt` (malformed port) or `net.getAddressList`.
fn parseAddr(allocator: Allocator, input: []const u8) !net.Address {
    const colon_index = mem.lastIndexOfScalar(u8, input, ':') orelse {
        log.err("Address '{s}' must be in host:port format", .{input});
        return error.InvalidAddressFormat;
    };
    const raw_host = input[0..colon_index];
    const port_part = input[colon_index + 1 ..];
    const port = try fmt.parseInt(u16, port_part, 10);

    // "[::1]" is the conventional way to write an IPv6 host next to a port,
    // but the resolver expects the bare address.
    const host_part = if (raw_host.len >= 2 and
        raw_host[0] == '[' and
        raw_host[raw_host.len - 1] == ']')
        raw_host[1 .. raw_host.len - 1]
    else
        raw_host;

    // `getAddressList` returns a heap-allocated list pointer that is never
    // reassigned here, so it is `const` (a never-mutated `var` does not
    // compile). `deinit` releases it on every exit path.
    const addr_list = try net.getAddressList(allocator, host_part, port);
    defer addr_list.deinit();
    if (addr_list.addrs.len == 0) {
        log.err("Could not resolve host '{s}'", .{host_part});
        return error.HostNotFound;
    }
    // Several addresses may resolve; the first one is used.
    return addr_list.addrs[0];
}
// --- TCP Proxy Logic ---
/// Everything one copy thread needs: bytes flow from `reader` into `writer`.
const CopyContext = struct {
    /// Direction label used only in log output, e.g. "client->backend".
    name: []const u8,
    /// Stream the thread reads from.
    reader: net.Stream.Reader,
    /// Stream that receives every byte read from `reader`.
    writer: net.Stream.Writer,
};
/// Thread entry point: copy bytes from `context.reader` to `context.writer`
/// until the source reaches EOF or either side fails.
///
/// Termination is propagated so the sibling copy thread (the opposite
/// direction) cannot block forever in `read`:
/// - on EOF, the write half of the destination socket is shut down, which
///   forwards the half-close to the destination peer;
/// - on a read or write error, both sockets are shut down in both directions,
///   which wakes any thread blocked on them.
/// Errors are logged rather than returned because this runs on its own thread.
fn copyStreamFn(context: *CopyContext) void {
    const buffer_size = 4096;
    var buffer: [buffer_size]u8 = undefined;
    // The stream readers/writers wrap a `net.Stream`; their raw handles are
    // needed for shutdown(2).
    const src_handle = context.reader.context.handle;
    const dst_handle = context.writer.context.handle;

    log.debug("Starting copy task: {s}", .{context.name});
    while (true) {
        const bytes_read = context.reader.read(&buffer) catch |err| {
            log.err("Copy task '{s}' failed reading: {any}", .{ context.name, err });
            shutdownSocketQuietly(context.name, src_handle, .both);
            shutdownSocketQuietly(context.name, dst_handle, .both);
            return;
        };
        if (bytes_read == 0) {
            // A blocking read returning 0 means the source peer closed its
            // sending side; pass that on to the destination.
            log.debug("Copy task '{s}' read 0 bytes, assuming EOF.", .{context.name});
            shutdownSocketQuietly(context.name, dst_handle, .send);
            return;
        }
        context.writer.writeAll(buffer[0..bytes_read]) catch |err| {
            log.err("Copy task '{s}' failed writing: {any}", .{ context.name, err });
            shutdownSocketQuietly(context.name, src_handle, .both);
            shutdownSocketQuietly(context.name, dst_handle, .both);
            return;
        };
    }
}

/// Best-effort shutdown(2) used by `copyStreamFn`. Failure is expected when
/// the peer has already gone away (e.g. the socket is no longer connected),
/// so it is only logged at debug level.
fn shutdownSocketQuietly(name: []const u8, handle: posix.socket_t, how: posix.ShutdownHow) void {
    posix.shutdown(handle, how) catch |err| {
        log.debug("Copy task '{s}' shutdown({s}) failed: {any}", .{ name, @tagName(how), err });
    };
}
/// Serve one accepted TCP client: connect to `backend_addr`, then copy bytes
/// in both directions on two helper threads and wait for both to finish.
///
/// Takes ownership of `client_conn`: its stream is closed before returning,
/// as is the backend stream once opened. Errors are logged, not returned,
/// because this function is used as a thread entry point.
fn handleTcpClient(client_conn: net.Server.Connection, backend_addr: net.Address) void {
    defer client_conn.stream.close();
    log.info("Accepted connection from {}", .{client_conn.address});

    // `net.Stream` is a by-value handle wrapper that is never reassigned, so
    // it is `const` (a never-mutated `var` is a compile error).
    const backend_stream = net.tcpConnectToAddress(backend_addr) catch |err| {
        log.err("Failed to connect to backend {}: {any}", .{ backend_addr, err });
        return; // the deferred close still releases the client stream
    };
    defer backend_stream.close();
    log.info("Connected to backend {} for client {}", .{ backend_addr, client_conn.address });

    // These contexts live on this stack frame. That is sound only because
    // every thread given a pointer to them is joined before returning.
    var ctx_client_to_backend = CopyContext{
        .reader = client_conn.stream.reader(),
        .writer = backend_stream.writer(),
        .name = "client->backend",
    };
    var ctx_backend_to_client = CopyContext{
        .reader = backend_stream.reader(),
        .writer = client_conn.stream.writer(),
        .name = "backend->client",
    };

    const thread_c2b = Thread.spawn(.{}, copyStreamFn, .{&ctx_client_to_backend}) catch |err| {
        log.err("Failed to spawn client->backend copy thread: {any}", .{err});
        return;
    };
    const thread_b2c = Thread.spawn(.{}, copyStreamFn, .{&ctx_backend_to_client}) catch |err| {
        log.err("Failed to spawn backend->client copy thread: {any}", .{err});
        // The already-running thread may be blocked in `read` for as long as
        // the client stays idle. Shut both sockets down so it sees EOF/error
        // and exits, then join it before its context goes out of scope.
        for ([_]posix.socket_t{ client_conn.stream.handle, backend_stream.handle }) |handle| {
            posix.shutdown(handle, .both) catch |shutdown_err| {
                log.debug("shutdown after spawn failure failed: {any}", .{shutdown_err});
            };
        }
        thread_c2b.join();
        return;
    };

    thread_c2b.join();
    thread_b2c.join();
    log.info("Closing connection for client {}", .{client_conn.address});
}
/// Accept TCP connections on `frontend_addr` forever and hand each one to a
/// detached `handleTcpClient` thread that forwards it to `backend_addr`.
///
/// Returns only on a fatal error: failure to listen, or an `accept` error
/// that is not known to be transient.
fn runTcpProxy(frontend_addr: net.Address, backend_addr: net.Address) !void {
    var server = try frontend_addr.listen(.{ .reuse_address = true });
    defer server.deinit();
    log.info("TCP Proxy listening on {}", .{server.listen_address});
    log.info("Forwarding connections to {}", .{backend_addr});

    while (true) {
        const connection = server.accept() catch |err| {
            log.err("Failed to accept connection: {any}", .{err});
            // These errors concern one client that went away during the
            // handshake; keep accepting. Anything else is treated as fatal.
            if (err == error.NetworkUnreachable or
                err == error.ConnectionAborted or
                err == error.ConnectionResetByPeer)
            {
                continue;
            }
            return err;
        };

        // The handler owns `connection`. If no thread can be created, drop
        // this one client rather than shutting the whole proxy down.
        const handler = Thread.spawn(.{}, handleTcpClient, .{ connection, backend_addr }) catch |err| {
            log.err("Failed to spawn handler thread for {}: {any}", .{ connection.address, err });
            connection.stream.close();
            continue;
        };
        // Handlers are never joined, so they must be detached explicitly;
        // merely discarding the `Thread` leaks its resources per connection.
        handler.detach();
    }
}
// --- UDP Proxy Logic ---
/// Largest UDP payload that fits in a single IPv4 datagram:
/// 65535 (maximum IPv4 total length) - 20 (minimal IPv4 header) - 8 (UDP header).
const MAX_UDP_PAYLOAD = 65535 - 20 - 8;
/// Minimal single-socket UDP proxy bound to `frontend_addr`.
///
/// The proxy classifies each datagram by its source address:
/// - a datagram from `backend_addr` is sent to the most recently seen client;
/// - any other datagram is treated as client traffic. Its sender becomes the
///   remembered client, and the datagram is forwarded to `backend_addr`.
///
/// Limitations:
/// - Only one client is tracked. Replies always go to whichever client sent
///   most recently, so concurrent clients receive each other's responses.
/// - Empty datagrams are dropped.
/// - Datagrams larger than `MAX_UDP_PAYLOAD` (possible over IPv6) may be
///   truncated by the receive call.
///
/// Runs until a fatal socket error occurs.
fn runUdpProxy(frontend_addr: net.Address, backend_addr: net.Address) !void {
    const proxy_sock_fd = try posix.socket(frontend_addr.any.family, posix.SOCK.DGRAM, 0);
    // A single `defer` covers both normal and error exits. Adding an
    // `errdefer` as well would close the descriptor twice on the error path.
    defer posix.close(proxy_sock_fd);

    try posix.setsockopt(proxy_sock_fd, posix.SOL.SOCKET, posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1)));
    try posix.bind(proxy_sock_fd, &frontend_addr.any, frontend_addr.getOsSockLen());
    log.info("UDP Proxy listening on {}", .{frontend_addr});
    log.info("Forwarding packets to {}", .{backend_addr});

    var buffer: [MAX_UDP_PAYLOAD]u8 = undefined;
    var last_client_addr: ?net.Address = null; // most recent client, if any
    // `net.Address` is a union large enough for every sockaddr variant it
    // supports (including sockaddr_in6), so it serves as recvfrom's output.
    var src_addr_storage: net.Address = undefined;

    while (true) {
        // recvfrom treats the length as in/out, so reset it on every call.
        // Use the size of the whole union, not the 16-byte generic
        // `sockaddr`, so IPv6 source addresses are not truncated.
        var src_addr_len: posix.socklen_t = @sizeOf(net.Address);
        const received_bytes = posix.recvfrom(
            proxy_sock_fd,
            &buffer,
            0, // flags
            @ptrCast(&src_addr_storage),
            &src_addr_len,
        ) catch |err| {
            // WouldBlock is not expected on this blocking socket, but it is
            // harmless. ConnectionRefused/ConnectionResetByPeer can report an
            // ICMP "port unreachable" caused by an earlier sendto to a peer
            // that was down. That concerns one peer, not this socket, so keep
            // serving.
            if (err == error.WouldBlock or
                err == error.ConnectionRefused or
                err == error.ConnectionResetByPeer)
            {
                log.warn("UDP recvfrom failed: {any}", .{err});
                continue;
            }
            log.err("UDP recvfrom failed: {any}", .{err});
            return err;
        };
        if (received_bytes == 0) {
            continue; // Ignore empty datagrams
        }

        const src_addr = net.Address.initPosix(@alignCast(@ptrCast(&src_addr_storage)));

        if (net.Address.eql(src_addr, backend_addr)) {
            // Reply from the backend: route it to the last known client.
            if (last_client_addr) |client_addr| {
                log.debug("UDP Backend -> Client ({}) : {} bytes", .{ client_addr, received_bytes });
                _ = posix.sendto(
                    proxy_sock_fd,
                    buffer[0..received_bytes],
                    0, // flags
                    &client_addr.any,
                    client_addr.getOsSockLen(),
                ) catch |err| {
                    // Not fatal: the client may simply have gone away.
                    log.warn("UDP sendto (to client {}) failed: {any}", .{ client_addr, err });
                };
            } else {
                log.warn("UDP Received packet from backend, but no client address known. Dropping.", .{});
            }
        } else {
            // Client traffic: remember the sender and forward to the backend.
            log.debug("UDP Client ({}) -> Backend : {} bytes", .{ src_addr, received_bytes });
            last_client_addr = src_addr;
            _ = posix.sendto(
                proxy_sock_fd,
                buffer[0..received_bytes],
                0, // flags
                &backend_addr.any,
                backend_addr.getOsSockLen(),
            ) catch |err| {
                // Not fatal: the backend may be temporarily down.
                log.warn("UDP sendto (to backend {}) failed: {any}", .{ backend_addr, err });
            };
        }
    }
}
// --- Main Function ---
/// Print command-line help to stderr, substituting `exe_name` into the
/// synopsis line and both example lines.
fn usage(exe_name: []const u8) void {
    const help_text = "Usage: {s} <tcp|udp> <frontend_host:port> <backend_host:port>\n" ++
        "\n" ++
        "Examples:\n" ++
        " {s} tcp 0.0.0.0:8080 192.168.1.100:80\n" ++
        " {s} udp 127.0.0.1:5353 8.8.8.8:53\n";
    std.debug.print(help_text, .{ exe_name, exe_name, exe_name });
}
/// Program entry point.
///
/// Parses `<tcp|udp> <frontend_host:port> <backend_host:port>` from argv,
/// resolves both addresses and runs the matching proxy loop. That loop only
/// returns on a fatal error. On bad input the usage text is printed and
/// `error.MissingArgument`, `error.InvalidProtocol` or the address-parsing
/// error is returned.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var args_iter = try process.argsWithAllocator(allocator);
    defer args_iter.deinit();
    // argv[0] is conventional but not guaranteed (a process can be exec'd
    // with an empty argv), so fall back to a generic name instead of
    // panicking on `.?`.
    const exe_name = args_iter.next() orelse "proxy";
    const protocol_str = args_iter.next() orelse {
        usage(exe_name);
        return error.MissingArgument;
    };
    const frontend_str = args_iter.next() orelse {
        usage(exe_name);
        return error.MissingArgument;
    };
    const backend_str = args_iter.next() orelse {
        usage(exe_name);
        return error.MissingArgument;
    };
    if (args_iter.next() != null) {
        log.warn("Ignoring extra arguments", .{});
    }

    // Only the exact enum tag names "tcp" and "udp" are accepted.
    const protocol = std.meta.stringToEnum(Args.Protocol, protocol_str) orelse {
        log.err("Invalid protocol '{s}'. Must be 'tcp' or 'udp'.", .{protocol_str});
        usage(exe_name);
        return error.InvalidProtocol;
    };
    const frontend_addr = parseAddr(allocator, frontend_str) catch |err| {
        log.err("Invalid frontend address '{s}': {any}", .{ frontend_str, err });
        return err;
    };
    const backend_addr = parseAddr(allocator, backend_str) catch |err| {
        log.err("Invalid backend address '{s}': {any}", .{ backend_str, err });
        return err;
    };
    const args = Args{
        .protocol = protocol,
        .frontend_addr = frontend_addr,
        .backend_addr = backend_addr,
    };

    switch (args.protocol) {
        .tcp => try runTcpProxy(args.frontend_addr, args.backend_addr),
        .udp => try runUdpProxy(args.frontend_addr, args.backend_addr),
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment