Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ const (
// This allows for alternate socket hosts for connecting to a session for failover.
// (i.e.) SocketConnectHost1, SocketConnectHost2... must be consecutive and have a matching SocketConnectPort<n>.
//
// Required: Yes for initiators
// Required: Yes for initiators on socket connections
//
// Default: None
//
Expand All @@ -547,7 +547,7 @@ const (
// This allows for alternate socket ports for connecting to a session for failover.
// (i.e.) SocketConnectPort1, SocketConnectPort2... must be consecutive and have a matching SocketConnectHost<n>.
//
// Required: Yes for initiators
// Required: Yes for initiators on socket connections
//
// Default: None
//
Expand Down Expand Up @@ -624,6 +624,27 @@ const (
// Valid Values:
// - Any string
ProxyPassword string = "ProxyPassword"

// WebsocketLocation sets the websocket endpoint to attempt to connect to.
// Setting this would override any SocketConnectHost and SocketConnectPort settings and connect using websocket
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - A websocket endpoint - eg. wss://example.com/ws
WebsocketLocation string = "WebsocketLocation"

// WebsocketOrigin sets the websocket origin to attempt to connect from.
//
// Required: No
//
// Default: N/A
//
// Valid Values:
// - url - eg. http://localhost/
WebsocketOrigin string = "WebsocketOrigin"
)

const (
Expand Down
92 changes: 88 additions & 4 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,113 @@
package quickfix

import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"

"golang.org/x/net/proxy"
"golang.org/x/net/websocket"

"github.com/quickfixgo/quickfix/config"
)

func loadDialerConfig(settings *SessionSettings) (dialer proxy.ContextDialer, err error) {
type Dialer interface {
Dial(ctx context.Context, session *session, attempt int, tlsConfig *tls.Config) (net.Conn, error)
}

type TcpDialer struct {

Check warning on line 36 in dialer.go

View workflow job for this annotation

GitHub Actions / Linter

var-naming: type TcpDialer should be TCPDialer (revive)
ctxDialer proxy.ContextDialer
}

func (d *TcpDialer) Dial(ctx context.Context, session *session, attempt int, tlsConfig *tls.Config) (conn net.Conn, err error) {
address := session.SocketConnectAddress[attempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

conn, err = d.ctxDialer.DialContext(ctx, "tcp", address)

if err != nil {
return
} else if tlsConfig != nil {
// Unless InsecureSkipVerify is true, server name config is required for TLS
// to verify the received certificate
if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 {
serverName := address
if c := strings.LastIndex(serverName, ":"); c > 0 {
serverName = serverName[:c]
}
tlsConfig.ServerName = serverName
}
tlsConn := tls.Client(conn, tlsConfig)
if err = tlsConn.Handshake(); err != nil {

session.log.OnEventf("Failed handshake: %v", err)
return
}
conn = tlsConn
}

return
}

type WebsocketDialer struct {
wsConfig *websocket.Config
}

func (d *WebsocketDialer) Dial(ctx context.Context, session *session, _ int, tlsConfig *tls.Config) (conn net.Conn, err error) {
session.log.OnEventf("Connecting to: %v", d.wsConfig.Location)

d.wsConfig.TlsConfig = tlsConfig
conn, err = d.wsConfig.DialContext(ctx)
return
}

func loadDialerConfig(settings *SessionSettings) (dialer Dialer, err error) {

if settings.HasSetting(config.WebsocketLocation) {
var location string
location, err = settings.Setting(config.WebsocketLocation)
if err != nil {
return nil, err
}

var origin string
origin, err = settings.Setting(config.WebsocketOrigin)
if err != nil {
return nil, err
}

var wsConfig *websocket.Config
wsConfig, err = websocket.NewConfig(location, origin)
if err != nil {
return nil, err
}

dialer = &WebsocketDialer{
wsConfig: wsConfig,
}
return
}

stdDialer := &net.Dialer{}
dialer = &TcpDialer{
ctxDialer: stdDialer,
}
if settings.HasSetting(config.SocketTimeout) {
timeout, err := settings.DurationSetting(config.SocketTimeout)
if err != nil {
timeoutInt, err := settings.IntSetting(config.SocketTimeout)
if err != nil {
return stdDialer, err
return nil, err
}

stdDialer.Timeout = time.Duration(timeoutInt) * time.Second
} else {
stdDialer.Timeout = timeout
}
}
dialer = stdDialer

if !settings.HasSetting(config.ProxyType) {
return
Expand Down Expand Up @@ -81,7 +163,9 @@
}

if contextDialer, ok := proxyDialer.(proxy.ContextDialer); ok {
dialer = contextDialer
dialer = &TcpDialer{
ctxDialer: contextDialer,
}
} else {
err = fmt.Errorf("proxy does not support context dialer")
return
Expand Down
6 changes: 3 additions & 3 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *DialerTestSuite) TestLoadDialerNoSettings() {
dialer, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().Nil(err)

stdDialer, ok := dialer.(*net.Dialer)
stdDialer, ok := dialer.(*TcpDialer).ctxDialer.(*net.Dialer)
s.Require().True(ok)
s.Require().NotNil(stdDialer)
s.Zero(stdDialer.Timeout)
Expand All @@ -53,7 +53,7 @@ func (s *DialerTestSuite) TestLoadDialerWithTimeout() {
dialer, err := loadDialerConfig(s.settings.GlobalSettings())
s.Require().Nil(err)

stdDialer, ok := dialer.(*net.Dialer)
stdDialer, ok := dialer.(*TcpDialer).ctxDialer.(*net.Dialer)
s.Require().True(ok)
s.Require().NotNil(stdDialer)
s.EqualValues(10*time.Second, stdDialer.Timeout)
Expand All @@ -73,7 +73,7 @@ func (s *DialerTestSuite) TestLoadDialerSocksProxy() {
s.Require().Nil(err)
s.Require().NotNil(dialer)

_, ok := dialer.(*net.Dialer)
_, ok := dialer.(*TcpDialer).ctxDialer.(*net.Dialer)
s.Require().False(ok)
}

Expand Down
35 changes: 8 additions & 27 deletions initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
"bufio"
"context"
"crypto/tls"
"strings"
"sync"
"time"

"golang.org/x/net/proxy"
)

// Initiator initiates connections and processes messages for all sessions.
Expand All @@ -48,12 +45,12 @@
// TODO: move into session factory.
var tlsConfig *tls.Config
if tlsConfig, err = loadTLSConfig(settings); err != nil {
return
return err
}

var dialer proxy.ContextDialer
var dialer Dialer
if dialer, err = loadDialerConfig(settings); err != nil {
return
return err
}

i.wg.Add(1)
Expand Down Expand Up @@ -143,7 +140,7 @@
return true
}

func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer proxy.ContextDialer) {
func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, dialer Dialer) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand Down Expand Up @@ -180,29 +177,13 @@
var msgIn chan fixIn
var msgOut chan []byte

address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

netConn, err := dialer.DialContext(ctx, "tcp", address)
netConn, err := dialer.Dial(ctx, session, connectionAttempt, tlsConfig)
if err != nil {
session.log.OnEventf("Failed to connect: %v", err)
goto reconnect
} else if tlsConfig != nil {
// Unless InsecureSkipVerify is true, server name config is required for TLS
// to verify the received certificate
if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 {
serverName := address
if c := strings.LastIndex(serverName, ":"); c > 0 {
serverName = serverName[:c]
}
tlsConfig.ServerName = serverName
}
tlsConn := tls.Client(netConn, tlsConfig)
if err = tlsConn.Handshake(); err != nil {
session.log.OnEventf("Failed handshake: %v", err)
goto reconnect
}
netConn = tlsConn
} else {

Check warning on line 184 in initiator.go

View workflow job for this annotation

GitHub Actions / Linter

superfluous-else: if block ends with a goto statement, so drop this else and outdent its block (revive)
address := netConn.RemoteAddr().String()
session.log.OnEventf("connected to remote address: %v", address)
}

msgIn = make(chan fixIn)
Expand Down
7 changes: 7 additions & 0 deletions session_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,13 @@ func (f sessionFactory) buildInitiatorSettings(session *session, settings *Sessi
func (f sessionFactory) configureSocketConnectAddress(session *session, settings *SessionSettings) (err error) {
session.SocketConnectAddress = []string{}

if !settings.HasSetting(config.SocketConnectHost) {
if !settings.HasSetting(config.WebsocketLocation) {
err = errors.New("SocketConnectHost must be specified if WebsocketLocation is not specified")
}
return
}

var socketConnectHost, socketConnectPort string
for i := 0; ; {

Expand Down
Loading