Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions client/iface/wgproxy/bind/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,19 @@ func (p *ProxyBind) Pause() {
}

func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false

ep, err := addrToEndpoint(endpoint)
if err != nil {
log.Errorf("failed to convert endpoint address: %v", err)
} else {
p.wgCurrentUsed = ep
log.Errorf("failed to start package redirection: %v", err)
return
}

p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}

func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, errors.New("nil address")
}
p.pausedCond.L.Lock()
p.paused = false

ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}
p.wgCurrentUsed = ep

addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
}

func (p *ProxyBind) CloseConn() error {
Expand Down Expand Up @@ -225,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
}

func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
if addr == nil {
return nil, fmt.Errorf("invalid address")
}
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return nil, fmt.Errorf("convert %s to netip.Addr", addr)
}

addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}, nil
}
55 changes: 42 additions & 13 deletions client/iface/wgproxy/ebpf/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ const (
)

var (
localHostNetIP = net.ParseIP("127.0.0.1")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")

serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)

// WGEBPFProxy definition for proxy with EBPF support
Expand Down Expand Up @@ -218,31 +224,54 @@ generatePort:
}

func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,

var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var dstIP net.IP

if endpointAddr.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
dstIP = localHostNetIPv4
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpointAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
dstIP = localHostNetIPv6
}

udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}

err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}

layerBuffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(data)

err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {

if _, err := p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
Expand Down
8 changes: 5 additions & 3 deletions client/iface/wgproxy/ebpf/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,14 @@ func (p *ProxyWrapper) Pause() {
}

func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
if endpoint == nil || endpoint.IP == nil {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
p.pausedCond.L.Lock()
p.paused = false

if endpoint != nil && endpoint.IP != nil {
p.wgEndpointCurrentUsedAddr = endpoint
}
p.wgEndpointCurrentUsedAddr = endpoint

p.pausedCond.Signal()
p.pausedCond.L.Unlock()
Expand Down
67 changes: 49 additions & 18 deletions client/iface/wgproxy/udp/rawsocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,22 @@ var (
FixLengths: true,
}

localHostNetIPAddr = &net.IPAddr{
localHostNetIPAddrV4 = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
localHostNetIPAddrV6 = &net.IPAddr{
IP: net.ParseIP("::1"),
}
)

type SrcFaker struct {
srcAddr *net.UDPAddr

rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
localHostAddr *net.IPAddr
}

func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
Expand All @@ -44,12 +48,18 @@ func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
return nil, err
}

localHostAddr := localHostNetIPAddrV4
if srcAddr.IP.To4() == nil {
localHostAddr = localHostNetIPAddrV6
}

f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
}

return f, nil
Expand All @@ -72,27 +82,48 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}

func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer

// Check if source IP is IPv4 or IPv6
if srcAddr.IP.To4() != nil {
// IPv4
ipv4 := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
} else {
// IPv6
ipv6 := &layers.IPv6{
DstIP: net.ParseIP("::1"),
SrcIP: srcAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
}

udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}

err := udpH.SetNetworkLayerForChecksum(ipH)
err := udpH.SetNetworkLayerForChecksum(networkLayer)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
Expand Down
Loading