Skip to content

Commit dc9dcd7

Browse files
committed
refactor(ssh): use context control on authentication helper functions
1 parent d062d2b commit dc9dcd7

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

ssh/session/session.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -386,9 +386,9 @@ func (s *Session) NewAgentChannel(name string, seat int) (*AgentChannel, error)
386386
return a, nil
387387
}
388388

389-
func (s *Session) checkFirewall() (bool, error) {
389+
func (s *Session) checkFirewall(ctx context.Context) (bool, error) {
390390
// TODO: Refactor firewall evaluation to remove the map requirement.
391-
if err := s.api.FirewallEvaluate(context.TODO(), map[string]string{
391+
if err := s.api.FirewallEvaluate(ctx, map[string]string{
392392
"domain": s.Namespace.Name,
393393
"name": s.Device.Name,
394394
"username": s.Target.Username,
@@ -417,8 +417,8 @@ func (s *Session) checkFirewall() (bool, error) {
417417
return true, nil
418418
}
419419

420-
func (s *Session) checkBilling() (bool, error) {
421-
device, err := s.api.GetDevice(context.TODO(), s.Device.UID)
420+
func (s *Session) checkBilling(ctx context.Context) (bool, error) {
421+
device, err := s.api.GetDevice(ctx, s.Device.UID)
422422
if err != nil {
423423
defer log.WithError(err).WithFields(log.Fields{
424424
"uid": s.UID,
@@ -446,8 +446,8 @@ func (s *Session) checkBilling() (bool, error) {
446446
}
447447

448448
// registerAPISession registers a new session on the API.
449-
func (s *Session) register() error {
450-
err := s.api.SessionCreate(context.TODO(), requests.SessionCreate{
449+
func (s *Session) register(ctx context.Context) error {
450+
err := s.api.SessionCreate(ctx, requests.SessionCreate{
451451
UID: s.UID,
452452
DeviceUID: s.Device.UID,
453453
Username: s.Target.Username,
@@ -469,10 +469,10 @@ func (s *Session) register() error {
469469
// Authenticate marks the session as authenticated on the API.
470470
//
471471
// It returns an error if authentication fails.
472-
func (s *Session) authenticate() error {
472+
func (s *Session) authenticate(ctx context.Context) error {
473473
value := true
474474

475-
return s.api.UpdateSession(context.TODO(), s.UID, &models.SessionUpdate{
475+
return s.api.UpdateSession(ctx, s.UID, &models.SessionUpdate{
476476
Authenticated: &value,
477477
})
478478
}
@@ -580,12 +580,12 @@ func (s *Session) Evaluate(ctx gliderssh.Context) error {
580580
snap := getSnapshot(ctx)
581581

582582
if envs.IsEnterprise() {
583-
if ok, err := s.checkFirewall(); err != nil || !ok {
583+
if ok, err := s.checkFirewall(ctx); err != nil || !ok {
584584
return err
585585
}
586586

587587
if envs.IsCloud() {
588-
if ok, err := s.checkBilling(); err != nil || !ok {
588+
if ok, err := s.checkBilling(ctx); err != nil || !ok {
589589
return err
590590
}
591591
}
@@ -624,7 +624,7 @@ func (s *Session) Auth(ctx gliderssh.Context, auth Auth) error {
624624
return err
625625
}
626626

627-
if err := sess.register(); err != nil {
627+
if err := sess.register(ctx); err != nil {
628628
return err
629629
}
630630

@@ -636,7 +636,7 @@ func (s *Session) Auth(ctx gliderssh.Context, auth Auth) error {
636636
return err
637637
}
638638

639-
if err := sess.authenticate(); err != nil {
639+
if err := sess.authenticate(ctx); err != nil {
640640
return err
641641
}
642642
default:

0 commit comments

Comments
 (0)