Skip to content

fix: improve SAML signature validation for redirect binding #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 168 additions & 33 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,27 +306,21 @@ func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.UR
// We can't depend on Query().set() as order matters for signing
query := rv.RawQuery
if len(query) > 0 {
query += "&SAMLRequest=" + url.QueryEscape(requestStr.String())
query += "&" + string(SAMLRequest) + "=" + url.QueryEscape(requestStr.String())
} else {
query += "SAMLRequest=" + url.QueryEscape(requestStr.String())
query += string(SAMLRequest) + "=" + url.QueryEscape(requestStr.String())
}

if relayState != "" {
query += "&RelayState=" + relayState
}
if len(sp.SignatureMethod) > 0 {
query += "&SigAlg=" + url.QueryEscape(sp.SignatureMethod)
signingContext, err := GetSigningContext(sp)

if err != nil {
return nil, err
var errSig error
query, errSig = sp.signQuery(SAMLRequest, query, requestStr.String(), relayState)
if errSig != nil {
return nil, errSig
}

sig, err := signingContext.SignString(query)
if err != nil {
return nil, err
}
query += "&Signature=" + url.QueryEscape(base64.StdEncoding.EncodeToString(sig))
}

rv.RawQuery = query
Expand Down Expand Up @@ -1364,7 +1358,6 @@ func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error {

// MakeLogoutRequest produces a new LogoutRequest object for idpURL.
func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequest, error) {

req := LogoutRequest{
ID: fmt.Sprintf("id-%x", randomBytes(20)),
IssueInstant: TimeNow(),
Expand All @@ -1381,11 +1374,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ
SPNameQualifier: sp.Metadata().EntityID,
},
}
if sp.SignatureMethod != "" {
if err := sp.SignLogoutRequest(&req); err != nil {
return nil, err
}
}

return &req, nil
}

Expand All @@ -1397,11 +1386,12 @@ func (sp *ServiceProvider) MakeRedirectLogoutRequest(nameID, relayState string)
if err != nil {
return nil, err
}
return req.Redirect(relayState), nil

return req.Redirect(relayState, sp)
}

// Redirect returns a URL suitable for using the redirect binding with the request
func (r *LogoutRequest) Redirect(relayState string) *url.URL {
func (r *LogoutRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) {
w := &bytes.Buffer{}
w1 := base64.NewEncoder(base64.StdEncoding, w)
w2, _ := flate.NewWriter(w1, 9)
Expand All @@ -1419,14 +1409,29 @@ func (r *LogoutRequest) Redirect(relayState string) *url.URL {

rv, _ := url.Parse(r.Destination)

query := rv.Query()
query.Set("SAMLRequest", w.String())
// We can't depend on Query().set() as order matters for signing
query := rv.RawQuery
if len(query) > 0 {
query += "&" + string(SAMLRequest) + "=" + url.QueryEscape(w.String())
} else {
query += string(SAMLRequest) + "=" + url.QueryEscape(w.String())
}

if relayState != "" {
query.Set("RelayState", relayState)
query += "&RelayState=" + relayState
}
rv.RawQuery = query.Encode()

return rv
if sp.SignatureMethod != "" {
var err error
query, err = sp.signQuery(SAMLRequest, query, w.String(), relayState)
if err != nil {
return nil, fmt.Errorf("logout request - redirect binding - failed to sign query: %v", err)
}
}

rv.RawQuery = query

return rv, nil
}

// MakePostLogoutRequest creates a SAML authentication request using
Expand All @@ -1437,6 +1442,13 @@ func (sp *ServiceProvider) MakePostLogoutRequest(nameID, relayState string) ([]b
if err != nil {
return nil, err
}

if sp.SignatureMethod != "" {
if err := sp.SignLogoutRequest(req); err != nil {
return nil, err
}
}

return req.Post(relayState), nil
}

Expand Down Expand Up @@ -1624,8 +1636,9 @@ func (sp *ServiceProvider) nameIDFormat() string {

// ValidateLogoutResponseRequest validates the LogoutResponse content from the request
func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) error {
if data := req.URL.Query().Get("SAMLResponse"); data != "" {
return sp.ValidateLogoutResponseRedirect(data)
query := req.URL.Query()
if data := query.Get("SAMLResponse"); data != "" {
return sp.ValidateLogoutResponseRedirect(req)
}

err := req.ParseForm()
Expand Down Expand Up @@ -1677,7 +1690,9 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
//
// URL Binding appears to be gzip / flate encoded
// See https://www.oasis-open.org/committees/download.php/20645/sstc-saml-tech-overview-2%200-draft-10.pdf 6.6
func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData string) error {
func (sp *ServiceProvider) ValidateLogoutResponseRedirect(r *http.Request) error {
query := r.URL.Query()
queryParameterData := query.Get("SAMLResponse")
retErr := &InvalidResponseError{
Now: TimeNow(),
}
Expand All @@ -1699,13 +1714,15 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
return err
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes(gr); err != nil {
retErr.PrivateErr = err
return retErr
if query.Get("Signature") != "" && query.Get("SigAlg") != "" {
if err := sp.validateRedirectBindingSignature(r); err != nil {
retErr.PrivateErr = err
return retErr
}
}

if err := sp.validateSignature(doc.Root()); err != nil {
doc := etree.NewDocument()
if err := doc.ReadFromBytes(gr); err != nil {
retErr.PrivateErr = err
return retErr
}
Expand All @@ -1715,6 +1732,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
retErr.PrivateErr = err
return retErr
}

return sp.validateLogoutResponse(&resp)
}

Expand All @@ -1738,6 +1756,123 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error {
return nil
}

// ValidateLogoutRequest validates the LogoutRequest content from the request
func (sp *ServiceProvider) ValidateLogoutRequest(req *http.Request) (*LogoutRequest, error) {
query := req.URL.Query()
if data := query.Get(string(SAMLRequest)); data != "" {
return sp.ValidateLogoutRequestRedirect(req)
}

err := req.ParseForm()
if err != nil {
return nil, fmt.Errorf("validateLogoutRequest: unable to parse form: %v", err)
}

return sp.ValidateLogoutRequestForm(req.PostForm.Get(string(SAMLRequest)))
}

// ValidateLogoutRequestRedirect returns a nil error if the logout request is valid. This is used for the HTTP Redirect binding.
func (sp *ServiceProvider) ValidateLogoutRequestRedirect(r *http.Request) (*LogoutRequest, error) {
query := r.URL.Query()
queryParameterData := query.Get(string(SAMLRequest))
retErr := &InvalidResponseError{
Now: TimeNow(),
}

rawRequestBuf, err := base64.StdEncoding.DecodeString(queryParameterData)
if err != nil {
retErr.PrivateErr = fmt.Errorf("validateLogoutRequestRedirect: unable to parse base64: %s", err)
return nil, retErr
}
retErr.Response = string(rawRequestBuf)

gr, err := io.ReadAll(newSaferFlateReader(bytes.NewBuffer(rawRequestBuf)))
if err != nil {
retErr.PrivateErr = err
return nil, retErr
}

if err := xrv.Validate(bytes.NewReader(gr)); err != nil {
return nil, fmt.Errorf("validateLogoutRequestRedirect: response contains invalid XML: %s", err)
}

if query.Get("Signature") != "" && query.Get("SigAlg") != "" {
if err := sp.validateRedirectBindingSignature(r); err != nil {
retErr.PrivateErr = err
return nil, retErr
}
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes(gr); err != nil {
retErr.PrivateErr = err
return nil, retErr
}

var req LogoutRequest
if err := unmarshalElement(doc.Root(), &req); err != nil {
retErr.PrivateErr = err
return nil, retErr
}

return &req, sp.validateLogoutRequest(&req)
}

// ValidateLogoutRequestForm returns a nil error if the logout request is valid. This is used for the HTTP POST binding.
func (sp *ServiceProvider) ValidateLogoutRequestForm(postFormData string) (*LogoutRequest, error) {
retErr := &InvalidResponseError{
Now: TimeNow(),
}

rawRequestBuf, err := base64.StdEncoding.DecodeString(postFormData)
if err != nil {
retErr.PrivateErr = fmt.Errorf("unable to parse base64: %s", err)
return nil, retErr
}
retErr.Response = string(rawRequestBuf)

if err := xrv.Validate(bytes.NewReader(rawRequestBuf)); err != nil {
return nil, fmt.Errorf("logout request contains invalid XML: %s", err)
}

doc := etree.NewDocument()
if err := doc.ReadFromBytes(rawRequestBuf); err != nil {
retErr.PrivateErr = err
return nil, retErr
}

if err := sp.validateSignature(doc.Root()); err != nil {
retErr.PrivateErr = err
return nil, retErr
}

var req LogoutRequest
if err := unmarshalElement(doc.Root(), &req); err != nil {
retErr.PrivateErr = err
return nil, retErr
}

return &req, sp.validateLogoutRequest(&req)
}

// validateLogoutRequest validates the LogoutRequest fields. Returns a nil error if the LogoutRequest is valid.
// This checks the destination, issue instant, and issuer.
func (sp *ServiceProvider) validateLogoutRequest(req *LogoutRequest) error {
if req.Destination != sp.SloURL.String() {
return fmt.Errorf("`Destination` does not match SloURL (expected %q)", sp.SloURL.String())
}

now := time.Now()
if req.IssueInstant.Add(MaxIssueDelay).Before(now) {
return fmt.Errorf("issueInstant expired at %s", req.IssueInstant.Add(MaxIssueDelay))
}
if req.Issuer.Value != sp.IDPMetadata.EntityID {
return fmt.Errorf("issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID)
}

return nil
}

func firstSet(a, b string) string {
if a == "" {
return b
Expand Down
Loading