diff --git a/common/net/listener.go b/common/net/listener.go new file mode 100644 index 00000000..0621b6c2 --- /dev/null +++ b/common/net/listener.go @@ -0,0 +1,90 @@ +package net + +import ( + "context" + "net" + "sync" +) + +type handleContextListener struct { + net.Listener + ctx context.Context + cancel context.CancelFunc + conns chan net.Conn + err error + once sync.Once + handle func(context.Context, net.Conn) (net.Conn, error) + panicLog func(any) +} + +func (l *handleContextListener) init() { + go func() { + for { + c, err := l.Listener.Accept() + if err != nil { + l.err = err + close(l.conns) + return + } + go func() { + defer func() { + if r := recover(); r != nil { + if l.panicLog != nil { + l.panicLog(r) + } + } + }() + if c, err := l.handle(l.ctx, c); err == nil { + l.conns <- c + } else { + // handle failed, close the underlying connection. + _ = c.Close() + } + }() + } + }() +} + +func (l *handleContextListener) Accept() (net.Conn, error) { + l.once.Do(l.init) + if c, ok := <-l.conns; ok { + return c, nil + } + return nil, l.err +} + +func (l *handleContextListener) Close() error { + l.cancel() + l.once.Do(func() { // l.init has not been called yet, so close related resources directly. + l.err = net.ErrClosed + close(l.conns) + }) + defer func() { + // at here, listener has been closed, so we should close all connections in the channel + for c := range l.conns { + go func(c net.Conn) { + defer func() { + if r := recover(); r != nil { + if l.panicLog != nil { + l.panicLog(r) + } + } + }() + _ = c.Close() + }(c) + } + }() + return l.Listener.Close() +} + +func NewHandleContextListener(ctx context.Context, l net.Listener, handle func(context.Context, net.Conn) (net.Conn, error), panicLog func(any)) net.Listener { + ctx, cancel := context.WithCancel(ctx) + return &handleContextListener{ + Listener: l, + ctx: ctx, + cancel: cancel, + conns: make(chan net.Conn), + handle: handle, + panicLog: panicLog, + } +} diff --git a/listener/reality/reality.go b/listener/reality/reality.go index 16ccc01c..036bcf28 100644 --- a/listener/reality/reality.go +++ b/listener/reality/reality.go @@ -9,6 +9,7 @@ import ( "net" "time" + N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/listener/inner" "github.com/metacubex/mihomo/log" @@ -79,11 +80,17 @@ type Builder struct { } func (b Builder) NewListener(l net.Listener) net.Listener { - l = utls.NewRealityListener(l, b.realityConfig) - // Due to low implementation quality, the reality server intercepted half close and caused memory leaks. - // We fixed it by calling Close() directly. - l = realityListenerWrapper{l} - return l + return N.NewHandleContextListener(context.Background(), l, func(ctx context.Context, conn net.Conn) (net.Conn, error) { + c, err := utls.RealityServer(ctx, conn, b.realityConfig) + if err != nil { + return nil, err + } + // Due to low implementation quality, the reality server intercepted half-close and caused memory leaks. + // We fixed it by calling Close() directly. + return realityConnWrapper{c}, nil + }, func(a any) { + log.Errorln("reality server panic: %s", a) + }) } type realityConnWrapper struct { @@ -97,15 +104,3 @@ func (c realityConnWrapper) Upstream() any { func (c realityConnWrapper) CloseWrite() error { return c.Close() } - -type realityListenerWrapper struct { - net.Listener -} - -func (l realityListenerWrapper) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - return realityConnWrapper{c.(*utls.Conn)}, nil -}