diff --git a/stun/agent.go b/stun/agent.go index 52ea00c..4f2fd1d 100644 --- a/stun/agent.go +++ b/stun/agent.go @@ -62,6 +62,8 @@ type Agent struct { config *Config Handler Handler m mux + + stopCh chan struct{} } func NewAgent(config *Config) *Agent { @@ -70,6 +72,8 @@ func NewAgent(config *Config) *Agent { } return &Agent{ config: config, + + stopCh: make(chan struct{}), } } @@ -98,44 +102,69 @@ func (a *Agent) ServeConn(c net.Conn) error { ) defer putBuffer(b) for { - if p >= len(b) { - return errBufferOverflow - } - n, err := c.Read(b[p:]) - if err != nil { - return err - } - p += n - n = 0 - for n < p { - r, err := a.ServeTransport(b[n:p], c) - if err != nil { + select { + case <-a.stopCh: + // stop muxes + a.m.Close() + c.SetReadDeadline(time.Time{}) // reset read deadline + return nil + default: + if p >= len(b) { + return errBufferOverflow + } + c.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + n, err := c.Read(b[p:]) + if err, ok := err.(net.Error); ok && err.Timeout() { + continue + } else if err != nil { return err } - n += r - } - if n > 0 { - if n < p { - p = copy(b, b[n:p]) - } else { - p = 0 + p += n + n = 0 + for n < p { + r, err := a.ServeTransport(b[n:p], c) + if err != nil { + return err + } + n += r + } + if n > 0 { + if n < p { + p = copy(b, b[n:p]) + } else { + p = 0 + } } } } } +func (a *Agent) Stop() { + a.stopCh <- struct{}{} +} + func (a *Agent) ServePacket(c net.PacketConn) error { b := getBuffer() defer putBuffer(b) - defer c.Close() + // defer c.Close() for { - n, addr, err := c.ReadFrom(b) - if err != nil { - return err - } - if n > 0 { - a.ServeTransport(b[:n], &packetConn{c, addr}) + select { + case <-a.stopCh: + a.m.Close() + c.SetReadDeadline(time.Time{}) // reset read deadline + return nil + default: + c.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + n, addr, err := c.ReadFrom(b) + if err, ok := err.(net.Error); ok && err.Timeout() { + continue + } else if err != nil { + return err + } + if n > 0 { + a.ServeTransport(b[:n], &packetConn{c, addr}) + } } } } @@ -204,14 +233,17 @@ func (m *mux) serve(msg *Message, tr Transport) bool { m.RUnlock() if ok { tx.msg, tx.from = msg, tr - tx.Done() + tx.done() return true } return false } func (m *mux) newTx() *transaction { - tx := &transaction{id: NewTransaction()} + tx := &transaction{ + doneCh: make(chan struct{}), + id: NewTransaction(), + } m.Lock() if m.t == nil { m.t = make(map[string]*transaction) @@ -241,17 +273,25 @@ func (m *mux) Close() { } type transaction struct { - sync.WaitGroup + doneCh chan struct{} + id []byte from Transport msg *Message err error } +func (tx *transaction) done() { + select { + case tx.doneCh <- struct{}{}: + default: + return + } +} + func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) { - tx.Add(1) t := time.AfterFunc(d, tx.timeout) - tx.Wait() + <-tx.doneCh t.Stop() if err = tx.err; err != nil { return @@ -261,12 +301,12 @@ func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, e func (tx *transaction) timeout() { tx.err = errTimeout - tx.Done() + tx.done() } func (tx *transaction) Close() { tx.err = errCanceled - tx.Done() + tx.done() } var errCanceled = errors.New("stun: transaction canceled") diff --git a/stun/stun.go b/stun/stun.go index d239978..92e5f09 100644 --- a/stun/stun.go +++ b/stun/stun.go @@ -9,6 +9,26 @@ import ( "strings" ) +// DiscoverConn allows discovery using an existing net.UDPConn, which is +// temporarily hijacked to communicate with the STUN server. The net.UDPConn +// is returned upon completion of the discovery process, or closed if an error +// occurs. +func DiscoverConn(stunAddr string, c *net.UDPConn) (*net.UDPAddr, error) { + stunUDPAddr, err := net.ResolveUDPAddr("udp", stunAddr) + if err != nil { + return nil, err + } + conn := NewConn(&packetConn{c, stunUDPAddr}, nil) + + addr, err := conn.Discover() + if err != nil { + conn.Close() + return nil, err + } + conn.agent.Stop() // stop the agent's read loop, relinquishing the conn + return addr.(*net.UDPAddr), nil +} + func Discover(uri string) (net.PacketConn, net.Addr, error) { conn, err := Dial(uri, nil) if err != nil { @@ -19,7 +39,7 @@ func Discover(uri string) (net.PacketConn, net.Addr, error) { conn.Close() return nil, nil, err } - // TODO: hijack + conn.agent.Stop() // stop the agent's read loop return conn.Conn.(net.PacketConn), addr, nil } diff --git a/stun/stun_test.go b/stun/stun_test.go new file mode 100644 index 0000000..d71215d --- /dev/null +++ b/stun/stun_test.go @@ -0,0 +1,26 @@ +package stun + +import ( + "net" + "testing" + "time" +) + +func TestDiscoverConn(t *testing.T) { + config := DefaultConfig + config.RetransmissionTimeout = 300 * time.Millisecond + config.TransactionTimeout = time.Second + conn, err := net.ListenUDP("udp", nil) + if err != nil { + t.Fatal(err) + } + + addr, err := DiscoverConn("stun.l.google.com:19302", conn) + if err != nil { + t.Fatal(err) + } + + if addr == nil { + t.Fatal("addr not determined") + } +}