Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow connections returned by go-stun to be used (for hole-punching, etc). #8

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 73 additions & 33 deletions stun/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ type Agent struct {
config *Config
Handler Handler
m mux

stopCh chan struct{}
}

func NewAgent(config *Config) *Agent {
Expand All @@ -70,6 +72,8 @@ func NewAgent(config *Config) *Agent {
}
return &Agent{
config: config,

stopCh: make(chan struct{}),
}
}

Expand Down Expand Up @@ -98,44 +102,69 @@ func (a *Agent) ServeConn(c net.Conn) error {
)
defer putBuffer(b)
for {
if p >= len(b) {
return errBufferOverflow
}
n, err := c.Read(b[p:])
if err != nil {
return err
}
p += n
n = 0
for n < p {
r, err := a.ServeTransport(b[n:p], c)
if err != nil {
select {
case <-a.stopCh:
// stop muxes
a.m.Close()
c.SetReadDeadline(time.Time{}) // reset read deadline
return nil
default:
if p >= len(b) {
return errBufferOverflow
}
c.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
n, err := c.Read(b[p:])
if err, ok := err.(net.Error); ok && err.Timeout() {
continue
} else if err != nil {
return err
}
n += r
}
if n > 0 {
if n < p {
p = copy(b, b[n:p])
} else {
p = 0
p += n
n = 0
for n < p {
r, err := a.ServeTransport(b[n:p], c)
if err != nil {
return err
}
n += r
}
if n > 0 {
if n < p {
p = copy(b, b[n:p])
} else {
p = 0
}
}
}
}
}

func (a *Agent) Stop() {
a.stopCh <- struct{}{}
}

func (a *Agent) ServePacket(c net.PacketConn) error {
b := getBuffer()
defer putBuffer(b)
defer c.Close()
// defer c.Close()

for {
n, addr, err := c.ReadFrom(b)
if err != nil {
return err
}
if n > 0 {
a.ServeTransport(b[:n], &packetConn{c, addr})
select {
case <-a.stopCh:
a.m.Close()
c.SetReadDeadline(time.Time{}) // reset read deadline
return nil
default:
c.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
n, addr, err := c.ReadFrom(b)
if err, ok := err.(net.Error); ok && err.Timeout() {
continue
} else if err != nil {
return err
}
if n > 0 {
a.ServeTransport(b[:n], &packetConn{c, addr})
}
}
}
}
Expand Down Expand Up @@ -204,14 +233,17 @@ func (m *mux) serve(msg *Message, tr Transport) bool {
m.RUnlock()
if ok {
tx.msg, tx.from = msg, tr
tx.Done()
tx.done()
return true
}
return false
}

func (m *mux) newTx() *transaction {
tx := &transaction{id: NewTransaction()}
tx := &transaction{
doneCh: make(chan struct{}),
id: NewTransaction(),
}
m.Lock()
if m.t == nil {
m.t = make(map[string]*transaction)
Expand Down Expand Up @@ -241,17 +273,25 @@ func (m *mux) Close() {
}

type transaction struct {
sync.WaitGroup
doneCh chan struct{}

id []byte
from Transport
msg *Message
err error
}

func (tx *transaction) done() {
select {
case tx.doneCh <- struct{}{}:
default:
return
}
}

func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, err error) {
tx.Add(1)
t := time.AfterFunc(d, tx.timeout)
tx.Wait()
<-tx.doneCh
t.Stop()
if err = tx.err; err != nil {
return
Expand All @@ -261,12 +301,12 @@ func (tx *transaction) Receive(d time.Duration) (msg *Message, from Transport, e

func (tx *transaction) timeout() {
tx.err = errTimeout
tx.Done()
tx.done()
}

func (tx *transaction) Close() {
tx.err = errCanceled
tx.Done()
tx.done()
}

var errCanceled = errors.New("stun: transaction canceled")
Expand Down
22 changes: 21 additions & 1 deletion stun/stun.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@ import (
"strings"
)

// DiscoverConn allows discovery using an existing net.UDPConn, which is
// temporarily hijacked to communicate with the STUN server. The net.UDPConn
// is returned upon completion of the discovery process, or closed if an error
// occurs.
func DiscoverConn(stunAddr string, c *net.UDPConn) (*net.UDPAddr, error) {
stunUDPAddr, err := net.ResolveUDPAddr("udp", stunAddr)
if err != nil {
return nil, err
}
conn := NewConn(&packetConn{c, stunUDPAddr}, nil)

addr, err := conn.Discover()
if err != nil {
conn.Close()
return nil, err
}
conn.agent.Stop() // stop the agent's read loop, relinquishing the conn
return addr.(*net.UDPAddr), nil
}

func Discover(uri string) (net.PacketConn, net.Addr, error) {
conn, err := Dial(uri, nil)
if err != nil {
Expand All @@ -19,7 +39,7 @@ func Discover(uri string) (net.PacketConn, net.Addr, error) {
conn.Close()
return nil, nil, err
}
// TODO: hijack
conn.agent.Stop() // stop the agent's read loop
return conn.Conn.(net.PacketConn), addr, nil
}

Expand Down
26 changes: 26 additions & 0 deletions stun/stun_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package stun

import (
"net"
"testing"
"time"
)

func TestDiscoverConn(t *testing.T) {
config := DefaultConfig
config.RetransmissionTimeout = 300 * time.Millisecond
config.TransactionTimeout = time.Second
conn, err := net.ListenUDP("udp", nil)
if err != nil {
t.Fatal(err)
}

addr, err := DiscoverConn("stun.l.google.com:19302", conn)
if err != nil {
t.Fatal(err)
}

if addr == nil {
t.Fatal("addr not determined")
}
}