Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 90 additions & 208 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,16 @@ import (
"context"
"crypto/rsa"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/netip"
"net/url"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/Masterminds/semver"
dockerclient "github.com/docker/docker/client"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/shellhub-io/shellhub/agent/pkg/keygen"
"github.com/shellhub-io/shellhub/agent/pkg/sysinfo"
Expand Down Expand Up @@ -112,6 +106,10 @@ type Config struct {
// MaxRetryConnectionTimeout specifies the maximum time, in seconds, that an agent will wait
// before attempting to reconnect to the ShellHub server. Default is 60 seconds.
MaxRetryConnectionTimeout int `env:"MAX_RETRY_CONNECTION_TIMEOUT,default=60" validate:"min=10,max=120"`

// ConnectionVersion specifies the version of the connection protocol to use.
// Supported values are 1 and 2. Default is 1.
ConnectionVersion int `env:"CONNECTION_VERSION,default=1"`
}

func LoadConfigFromEnv() (*Config, map[string]interface{}, error) {
Expand Down Expand Up @@ -161,10 +159,13 @@ type Agent struct {
cli client.Client
serverInfo *models.Info
server *server.Server
tunnel *tunnel.Tunnel
listening chan bool
closed atomic.Bool
mode Mode
// listener is the current connection to the server.
listener atomic.Pointer[net.Listener]
// logger is the agent's logger instance.
logger *log.Entry
}

// NewAgent creates a new agent instance, requiring the ShellHub server's address to connect to, the namespace's tenant
Expand Down Expand Up @@ -255,6 +256,16 @@ func (a *Agent) Initialize() error {

a.closed.Store(false)

a.logger = log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"ssh_endpoint": a.serverInfo.Endpoints.SSH,
"api_endpoint": a.serverInfo.Endpoints.API,
"connection_version": a.config.ConnectionVersion,
"sshid": fmt.Sprintf("%s.%s@%s", a.authData.Namespace, a.authData.Name, strings.Split(a.serverInfo.Endpoints.SSH, ":")[0]),
})

return nil
}

Expand Down Expand Up @@ -370,261 +381,132 @@ func (a *Agent) isClosed() bool {
func (a *Agent) Close() error {
a.closed.Store(true)

return a.tunnel.Close()
}

func sshHandler(serv *server.Server) func(c echo.Context) error {
return func(c echo.Context) error {
hj, ok := c.Response().Writer.(http.Hijacker)
if !ok {
return c.String(http.StatusInternalServerError, "webserver doesn't support hijacking")
}
l := a.listener.Load()
if l == nil {
return nil
}

conn, _, err := hj.Hijack()
if err != nil {
return c.String(http.StatusInternalServerError, "failed to hijack connection")
}
return (*l).Close()
}

id := c.Param("id")
httpConn := c.Request().Context().Value("http-conn").(net.Conn)
serv.Sessions.Store(id, httpConn)
serv.HandleConn(httpConn)
const (
ConnectionV1 = 1
ConnectionV2 = 2
)

conn.Close()
func (a *Agent) Listen(ctx context.Context) error {
a.mode.Serve(a)

return nil
switch a.config.ConnectionVersion {
case ConnectionV1:
return a.listenV1(ctx)
case ConnectionV2:
return a.listenV2(ctx)
default:
return fmt.Errorf("unsupported connection version: %d", a.config.ConnectionVersion)
}
}

// httpProxyHandler handlers proxy connections to the required address.
func httpProxyHandler(agent *Agent) func(c echo.Context) error {
const ProxyHandlerNetwork = "tcp"
func (a *Agent) listenV1(ctx context.Context) error {
tun := tunnel.NewTunnelV1()

return func(c echo.Context) error {
logger := log.WithFields(log.Fields{
"remote": c.Request().RemoteAddr,
"namespace": c.Request().Header.Get("X-Namespace"),
"path": c.Request().Header.Get("X-Path"),
"version": AgentVersion,
})
tun.Handle(HandleSSHOpenV1, sshHandlerV1(a))
tun.Handle(HandleSSHCloseV1, sshCloseHandlerV1(a))
tun.Handle(HandleHTTPProxyV1, httpProxyHandlerV1(a))

errorResponse := func(err error, msg string, code int) error {
logger.WithError(err).Debug(msg)
go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck

return c.String(code, msg)
}
ctx, cancel := context.WithCancel(ctx)
go func() {
for {
if a.isClosed() {
a.logger.Info("Stopped listening for connections")

host, port, err := net.SplitHostPort(c.Param("addr"))
if err != nil {
return errorResponse(err, "failed because address is invalid", http.StatusInternalServerError)
}
cancel()

if _, ok := agent.mode.(*ConnectorMode); ok {
cli, err := dockerclient.NewClientWithOpts(dockerclient.FromEnv, dockerclient.WithAPIVersionNegotiation())
if err != nil {
return errorResponse(err, "failed to connect to the Docker Engine", http.StatusInternalServerError)
return
}

container, err := cli.ContainerInspect(context.Background(), agent.server.ContainerID)
if err != nil {
return errorResponse(err, "failed to inspect the container", http.StatusInternalServerError)
}
ShellHubConnectV1Path := "/ssh/connection"

var target string
a.logger.Debug("Using tunnel version 1")

addr, err := netip.ParseAddr(host)
listener, err := a.cli.NewReverseListenerV1(
ctx,
a.authData.Token,
ShellHubConnectV1Path,
)
if err != nil {
return errorResponse(err, "failed to parse the for lookback checkage", http.StatusInternalServerError)
}

if addr.IsLoopback() {
for _, network := range container.NetworkSettings.Networks {
target = network.IPAddress

break
}
} else {
for _, network := range container.NetworkSettings.Networks {
subnet, err := netip.ParsePrefix(fmt.Sprintf("%s/%d", network.Gateway, network.IPPrefixLen))
if err != nil {
logger.WithError(err).Trace("Failed to parse the gateway on proxy")

continue
}

ip, err := netip.ParseAddr(host)
if err != nil {
logger.WithError(err).Trace("Failed to parse the address on proxy")
a.logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")

continue
}

if subnet.Contains(ip) {
target = ip.String()

break
}
}
}
time.Sleep(time.Second * 10)

if target == "" {
return errorResponse(nil, "address not found on the device", http.StatusInternalServerError)
continue
}
a.listener.Store(&listener)

host = target
}

// NOTE: Gets the to address to connect to. This address can be just a port, :8080, or the host and port,
// localhost:8080.
addr := fmt.Sprintf("%s:%s", host, port)
a.logger.Info("Server connection established")

in, err := net.Dial(ProxyHandlerNetwork, addr)
if err != nil {
return errorResponse(err, "failed to connect to the server on device", http.StatusInternalServerError)
}

defer in.Close()
a.listening <- true

// NOTE: Inform to the connection that the dial was successfully.
if err := c.NoContent(http.StatusOK); err != nil {
return errorResponse(err, "failed to send the ok status code back to server", http.StatusInternalServerError)
}
if err := tun.Listen(ctx, listener); err != nil {
a.logger.WithError(err).Error("Tunnel listener exited with error")
}

// NOTE: Hijacks the connection to control the data transferred to the client connected. This way, we don't
// depend upon anything externally, only the data.
out, _, err := c.Response().Hijack()
if err != nil {
return errorResponse(err, "failed to hijack connection", http.StatusInternalServerError)
a.listening <- false
}
}()

defer out.Close() // nolint:errcheck

wg := new(sync.WaitGroup)
done := sync.OnceFunc(func() {
defer in.Close()
defer out.Close()

logger.Trace("close called on in and out connections")
})

wg.Add(1)
go func() {
defer done()
defer wg.Done()

io.Copy(in, out) //nolint:errcheck
}()

wg.Add(1)
go func() {
defer done()
defer wg.Done()

io.Copy(out, in) //nolint:errcheck
}()

logger.WithError(err).Trace("proxy handler waiting for data pipe")
wg.Wait()

logger.WithError(err).Trace("proxy handler done")

return nil
}
}

func sshCloseHandler(a *Agent, serv *server.Server) func(c echo.Context) error {
return func(c echo.Context) error {
id := c.Param("id")
serv.CloseSession(id)

log.WithFields(
log.Fields{
"id": id,
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
},
).Info("A tunnel connection was closed")
<-ctx.Done()

return nil
}
return a.Close()
}

// Listen creates the SSH server and listening for connections.
func (a *Agent) Listen(ctx context.Context) error {
a.mode.Serve(a)
func (a *Agent) listenV2(ctx context.Context) error {
tun := tunnel.NewTunnelV2(a.cli)

a.tunnel = tunnel.NewBuilder().
WithSSHHandler(sshHandler(a.server)).
WithSSHCloseHandler(sshCloseHandler(a, a.server)).
WithHTTPProxyHandler(httpProxyHandler(a)).
Build()
tun.Handle(HandleSSHOpenV2, sshHandlerV2(a))
tun.Handle(HandleSSHCloseV2, sshCloseHandlerV2(a))
tun.Handle(HandleHTTPProxyV2, httpProxyHandlerV2(a))

go a.ping(ctx, AgentPingDefaultInterval) //nolint:errcheck

ctx, cancel := context.WithCancel(ctx)
go func() {
for {
if a.isClosed() {
log.WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
}).Info("Stopped listening for connections")
a.logger.Info("Stopped listening for connections")

cancel()

return
}

namespace := a.authData.Namespace
tenantName := a.authData.Name
sshEndpoint := a.serverInfo.Endpoints.SSH
ShellHubConnectV2Path := "/agent/connection"

sshid := strings.NewReplacer(
"{namespace}", namespace,
"{tenantName}", tenantName,
"{sshEndpoint}", strings.Split(sshEndpoint, ":")[0],
).Replace("{namespace}.{tenantName}@{sshEndpoint}")
a.logger.Debug("Using tunnel version 2")

listener, err := a.cli.NewReverseListener(ctx, a.authData.Token, "/ssh/connection")
listener, err := a.cli.NewReverseListenerV2(
ctx,
a.authData.Token,
ShellHubConnectV2Path,
client.NewReverseV2ConfigFromMap(a.authData.Config),
)
if err != nil {
log.WithError(err).WithFields(log.Fields{
"version": AgentVersion,
"tenant_id": a.authData.Namespace,
"server_address": a.config.ServerAddress,
"ssh_server": sshEndpoint,
"sshid": sshid,
}).Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")
a.logger.Error("Failed to connect to server through reverse tunnel. Retry in 10 seconds")

time.Sleep(time.Second * 10)

continue
}
a.listener.Store(&listener)

log.WithFields(log.Fields{
"namespace": namespace,
"hostname": tenantName,
"server_address": a.config.ServerAddress,
"ssh_server": sshEndpoint,
"sshid": sshid,
}).Info("Server connection established")
a.logger.Info("Server connection established")

a.listening <- true

{
// NOTE: Tunnel'll only realize that it lost its connection to the ShellHub SSH when the next
// "keep-alive" connection fails. As a result, it will take this interval to reconnect to its server.
err := a.tunnel.Listen(listener)

log.WithError(err).WithFields(log.Fields{
"namespace": namespace,
"hostname": tenantName,
"server_address": a.config.ServerAddress,
"ssh_server": sshEndpoint,
"sshid": sshid,
}).Info("Tunnel listener closed")

listener.Close() // nolint:errcheck
if err := tun.Listen(ctx, listener); err != nil {
a.logger.WithError(err).Error("Tunnel listener exited with error")
}

a.listening <- false
Expand Down
Loading