diff --git a/common/once/oncefunc.go b/common/once/oncefunc.go new file mode 100644 index 00000000..80c00f88 --- /dev/null +++ b/common/once/oncefunc.go @@ -0,0 +1,102 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package once + +import "sync" + +// OnceFunc returns a function that invokes f only once. The returned function +// may be called concurrently. +// +// If f panics, the returned function will panic with the same value on every call. +func OnceFunc(f func()) func() { + var ( + once sync.Once + valid bool + p any + ) + // Construct the inner closure just once to reduce costs on the fast path. + g := func() { + defer func() { + p = recover() + if !valid { + // Re-panic immediately so on the first call the user gets a + // complete stack trace into f. + panic(p) + } + }() + f() + f = nil // Do not keep f alive after invoking it. + valid = true // Set only if f does not panic. + } + return func() { + once.Do(g) + if !valid { + panic(p) + } + } +} + +// OnceValue returns a function that invokes f only once and returns the value +// returned by f. The returned function may be called concurrently. +// +// If f panics, the returned function will panic with the same value on every call. +func OnceValue[T any](f func() T) func() T { + var ( + once sync.Once + valid bool + p any + result T + ) + g := func() { + defer func() { + p = recover() + if !valid { + panic(p) + } + }() + result = f() + f = nil + valid = true + } + return func() T { + once.Do(g) + if !valid { + panic(p) + } + return result + } +} + +// OnceValues returns a function that invokes f only once and returns the values +// returned by f. The returned function may be called concurrently. +// +// If f panics, the returned function will panic with the same value on every call. +func OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) { + var ( + once sync.Once + valid bool + p any + r1 T1 + r2 T2 + ) + g := func() { + defer func() { + p = recover() + if !valid { + panic(p) + } + }() + r1, r2 = f() + f = nil + valid = true + } + return func() (T1, T2) { + once.Do(g) + if !valid { + panic(p) + } + return r1, r2 + } +} diff --git a/component/tls/reality.go b/component/tls/reality.go index eee37384..6a5cdc5f 100644 --- a/component/tls/reality.go +++ b/component/tls/reality.go @@ -37,9 +37,8 @@ type RealityConfig struct { ShortID [RealityMaxShortIDLen]byte } -func GetRealityConn(ctx context.Context, conn net.Conn, clientFingerprint string, tlsConfig *tls.Config, realityConfig *RealityConfig) (net.Conn, error) { - retry := 0 - for fingerprint, exists := GetFingerprint(clientFingerprint); exists; retry++ { +func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, tlsConfig *tls.Config, realityConfig *RealityConfig) (net.Conn, error) { + for retry := 0; ; retry++ { verifier := &realityVerifier{ serverName: tlsConfig.ServerName, } @@ -151,7 +150,6 @@ func GetRealityConn(ctx context.Context, conn net.Conn, clientFingerprint string return uConn, nil } - return nil, errors.New("unknown uTLS fingerprint") } func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) { diff --git a/component/tls/utls.go b/component/tls/utls.go index 7b67fef1..80b37f38 100644 --- a/component/tls/utls.go +++ b/component/tls/utls.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net" + "github.com/metacubex/mihomo/common/once" "github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/log" @@ -11,46 +12,44 @@ import ( "github.com/mroth/weightedrand/v2" ) +type Conn = utls.Conn type UConn = utls.UConn +type UClientHelloID = utls.ClientHelloID const VersionTLS13 = utls.VersionTLS13 -type UClientHelloID struct { - *utls.ClientHelloID +func Client(c net.Conn, config *utls.Config) *Conn { + return utls.Client(c, config) } -var initRandomFingerprint UClientHelloID -var initUtlsClient string - func UClient(c net.Conn, config *utls.Config, fingerprint UClientHelloID) *UConn { - return utls.UClient(c, config, *fingerprint.ClientHelloID) + return utls.UClient(c, config, fingerprint) } -func GetFingerprint(ClientFingerprint string) (UClientHelloID, bool) { - if ClientFingerprint == "none" { +func GetFingerprint(clientFingerprint string) (UClientHelloID, bool) { + if len(clientFingerprint) == 0 { + clientFingerprint = globalFingerprint + } + if len(clientFingerprint) == 0 || clientFingerprint == "none" { return UClientHelloID{}, false } - if initRandomFingerprint.ClientHelloID == nil { - initRandomFingerprint, _ = RollFingerprint() + if clientFingerprint == "random" { + fingerprint := randomFingerprint() + log.Debugln("use initial random HelloID:%s", fingerprint.Client) + return fingerprint, true } - if ClientFingerprint == "random" { - log.Debugln("use initial random HelloID:%s", initRandomFingerprint.Client) - return initRandomFingerprint, true - } - - fingerprint, ok := Fingerprints[ClientFingerprint] - if ok { + if fingerprint, ok := fingerprints[clientFingerprint]; ok { log.Debugln("use specified fingerprint:%s", fingerprint.Client) - return fingerprint, ok + return fingerprint, true } else { - log.Warnln("wrong ClientFingerprint:%s", ClientFingerprint) + log.Warnln("wrong clientFingerprint:%s", clientFingerprint) return UClientHelloID{}, false } } -func RollFingerprint() (UClientHelloID, bool) { +var randomFingerprint = once.OnceValue(func() UClientHelloID { chooser, _ := weightedrand.NewChooser( weightedrand.NewChoice("chrome", 6), weightedrand.NewChoice("safari", 3), @@ -59,26 +58,29 @@ func RollFingerprint() (UClientHelloID, bool) { ) initClient := chooser.Pick() log.Debugln("initial random HelloID:%s", initClient) - fingerprint, ok := Fingerprints[initClient] - return fingerprint, ok -} + fingerprint, ok := fingerprints[initClient] + if !ok { + log.Warnln("error in initial random HelloID:%s", initClient) + } + return fingerprint +}) -var Fingerprints = map[string]UClientHelloID{ - "chrome": {&utls.HelloChrome_Auto}, - "chrome_psk": {&utls.HelloChrome_100_PSK}, - "chrome_psk_shuffle": {&utls.HelloChrome_106_Shuffle}, - "chrome_padding_psk_shuffle": {&utls.HelloChrome_114_Padding_PSK_Shuf}, - "chrome_pq": {&utls.HelloChrome_115_PQ}, - "chrome_pq_psk": {&utls.HelloChrome_115_PQ_PSK}, - "firefox": {&utls.HelloFirefox_Auto}, - "safari": {&utls.HelloSafari_Auto}, - "ios": {&utls.HelloIOS_Auto}, - "android": {&utls.HelloAndroid_11_OkHttp}, - "edge": {&utls.HelloEdge_Auto}, - "360": {&utls.Hello360_Auto}, - "qq": {&utls.HelloQQ_Auto}, - "random": {nil}, - "randomized": {nil}, +var fingerprints = map[string]UClientHelloID{ + "chrome": utls.HelloChrome_Auto, + "chrome_psk": utls.HelloChrome_100_PSK, + "chrome_psk_shuffle": utls.HelloChrome_106_Shuffle, + "chrome_padding_psk_shuffle": utls.HelloChrome_114_Padding_PSK_Shuf, + "chrome_pq": utls.HelloChrome_115_PQ, + "chrome_pq_psk": utls.HelloChrome_115_PQ_PSK, + "firefox": utls.HelloFirefox_Auto, + "safari": utls.HelloSafari_Auto, + "ios": utls.HelloIOS_Auto, + "android": utls.HelloAndroid_11_OkHttp, + "edge": utls.HelloEdge_Auto, + "360": utls.Hello360_Auto, + "qq": utls.HelloQQ_Auto, + "random": {}, + "randomized": utls.HelloRandomized, } func init() { @@ -88,7 +90,7 @@ func init() { randomized := utls.HelloRandomized randomized.Seed, _ = utls.NewPRNGSeed() randomized.Weights = &weights - Fingerprints["randomized"] = UClientHelloID{&randomized} + fingerprints["randomized"] = randomized } func UCertificates(it tls.Certificate) utls.Certificate { @@ -154,14 +156,12 @@ func BuildWebsocketHandshakeState(c *UConn) error { return nil } -func SetGlobalUtlsClient(Client string) { - initUtlsClient = Client -} +var globalFingerprint string -func HaveGlobalFingerprint() bool { - return len(initUtlsClient) != 0 && initUtlsClient != "none" +func SetGlobalFingerprint(fingerprint string) { + globalFingerprint = fingerprint } func GetGlobalFingerprint() string { - return initUtlsClient + return globalFingerprint } diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 670fa7b8..dd5f0912 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -454,7 +454,7 @@ func updateGeneral(general *config.General, logging bool) { mihomoHttp.SetUA(general.GlobalUA) resource.SetETag(general.ETagSupport) - tlsC.SetGlobalUtlsClient(general.GlobalClientFingerprint) + tlsC.SetGlobalFingerprint(general.GlobalClientFingerprint) } func updateUsers(users []auth.AuthUser) { diff --git a/transport/gun/gun.go b/transport/gun/gun.go index 68f4b2d9..13d4046d 100644 --- a/transport/gun/gun.go +++ b/transport/gun/gun.go @@ -237,25 +237,19 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, clientFingerprint stri return pconn, nil } - clientFingerprint := clientFingerprint - if tlsC.HaveGlobalFingerprint() && len(clientFingerprint) == 0 { - clientFingerprint = tlsC.GetGlobalFingerprint() - } - if len(clientFingerprint) != 0 { + if clientFingerprint, ok := tlsC.GetFingerprint(clientFingerprint); ok { if realityConfig == nil { - if fingerprint, exists := tlsC.GetFingerprint(clientFingerprint); exists { - utlsConn := tlsC.UClient(pconn, tlsC.UConfig(cfg), fingerprint) - if err := utlsConn.HandshakeContext(ctx); err != nil { - pconn.Close() - return nil, err - } - state := utlsConn.ConnectionState() - if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { - utlsConn.Close() - return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS) - } - return utlsConn, nil + tlsConn := tlsC.UClient(pconn, tlsC.UConfig(cfg), clientFingerprint) + if err := tlsConn.HandshakeContext(ctx); err != nil { + pconn.Close() + return nil, err } + state := tlsConn.ConnectionState() + if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { + tlsConn.Close() + return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS) + } + return tlsConn, nil } else { realityConn, err := tlsC.GetRealityConn(ctx, pconn, clientFingerprint, cfg, realityConfig) if err != nil { diff --git a/transport/sing-shadowtls/shadowtls.go b/transport/sing-shadowtls/shadowtls.go index 2f916653..904bcd63 100644 --- a/transport/sing-shadowtls/shadowtls.go +++ b/transport/sing-shadowtls/shadowtls.go @@ -10,7 +10,6 @@ import ( "github.com/metacubex/mihomo/log" "github.com/metacubex/sing-shadowtls" - utls "github.com/metacubex/utls" "golang.org/x/exp/slices" ) @@ -67,26 +66,21 @@ func uTLSHandshakeFunc(config *tls.Config, clientFingerprint string) shadowtls.T return func(ctx context.Context, conn net.Conn, sessionIDGenerator shadowtls.TLSSessionIDGeneratorFunc) error { tlsConfig := tlsC.UConfig(config) tlsConfig.SessionIDGenerator = sessionIDGenerator - clientFingerprint := clientFingerprint - if tlsC.HaveGlobalFingerprint() && len(clientFingerprint) == 0 { - clientFingerprint = tlsC.GetGlobalFingerprint() - } if config.MaxVersion == tls.VersionTLS12 { // for ShadowTLS v1 - clientFingerprint = "" + tlsConn := tlsC.Client(conn, tlsConfig) + return tlsConn.HandshakeContext(ctx) } - if len(clientFingerprint) != 0 { - if fingerprint, exists := tlsC.GetFingerprint(clientFingerprint); exists { - tlsConn := tlsC.UClient(conn, tlsConfig, fingerprint) - if slices.Equal(tlsConfig.NextProtos, WsALPN) { - err := tlsC.BuildWebsocketHandshakeState(tlsConn) - if err != nil { - return err - } + if clientFingerprint, ok := tlsC.GetFingerprint(clientFingerprint); ok { + tlsConn := tlsC.UClient(conn, tlsConfig, clientFingerprint) + if slices.Equal(tlsConfig.NextProtos, WsALPN) { + err := tlsC.BuildWebsocketHandshakeState(tlsConn) + if err != nil { + return err } - return tlsConn.HandshakeContext(ctx) } + return tlsConn.HandshakeContext(ctx) } - tlsConn := utls.Client(conn, tlsConfig) + tlsConn := tlsC.Client(conn, tlsConfig) return tlsConn.HandshakeContext(ctx) } } diff --git a/transport/vmess/tls.go b/transport/vmess/tls.go index 69871bb8..588c159a 100644 --- a/transport/vmess/tls.go +++ b/transport/vmess/tls.go @@ -32,20 +32,14 @@ func StreamTLSConn(ctx context.Context, conn net.Conn, cfg *TLSConfig) (net.Conn return nil, err } - clientFingerprint := cfg.ClientFingerprint - if tlsC.HaveGlobalFingerprint() && len(clientFingerprint) == 0 { - clientFingerprint = tlsC.GetGlobalFingerprint() - } - if len(clientFingerprint) != 0 { + if clientFingerprint, ok := tlsC.GetFingerprint(cfg.ClientFingerprint); ok { if cfg.Reality == nil { - if fingerprint, exists := tlsC.GetFingerprint(clientFingerprint); exists { - utlsConn := tlsC.UClient(conn, tlsC.UConfig(tlsConfig), fingerprint) - err = utlsConn.HandshakeContext(ctx) - if err != nil { - return nil, err - } - return utlsConn, nil + tlsConn := tlsC.UClient(conn, tlsC.UConfig(tlsConfig), clientFingerprint) + err = tlsConn.HandshakeContext(ctx) + if err != nil { + return nil, err } + return tlsConn, nil } else { return tlsC.GetRealityConn(ctx, conn, clientFingerprint, tlsConfig, cfg.Reality) } diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 6a8963fd..7e8886b6 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -351,31 +351,26 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, } if config.ServerName == "" && !config.InsecureSkipVerify { // users must set either ServerName or InsecureSkipVerify in the config. config = config.Clone() - config.ServerName = uri.Host + config.ServerName = c.Host } - clientFingerprint := c.ClientFingerprint - if tlsC.HaveGlobalFingerprint() && len(clientFingerprint) == 0 { - clientFingerprint = tlsC.GetGlobalFingerprint() - } - if len(clientFingerprint) != 0 { - if fingerprint, exists := tlsC.GetFingerprint(clientFingerprint); exists { - utlsConn := tlsC.UClient(conn, tlsC.UConfig(config), fingerprint) - if err = tlsC.BuildWebsocketHandshakeState(utlsConn); err != nil { - return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) - } - conn = utlsConn + if clientFingerprint, ok := tlsC.GetFingerprint(c.ClientFingerprint); ok { + tlsConn := tlsC.UClient(conn, tlsC.UConfig(config), clientFingerprint) + if err = tlsC.BuildWebsocketHandshakeState(tlsConn); err != nil { + return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } - } else { - conn = tls.Client(conn, config) - } - - if tlsConn, ok := conn.(interface { - HandshakeContext(ctx context.Context) error - }); ok { - if err = tlsConn.HandshakeContext(ctx); err != nil { + err = tlsConn.HandshakeContext(ctx) + if err != nil { return nil, err } + conn = tlsConn + } else { + tlsConn := tls.Client(conn, config) + err = tlsConn.HandshakeContext(ctx) + if err != nil { + return nil, err + } + conn = tlsConn } }