@@ -50,6 +50,7 @@ type ServerGuard struct {
50
50
IsSafeMode func (request * http.Request ) bool
51
51
Validate func (request * http.Request ) (* ServerGuard , error )
52
52
ShouldReturnRawResponse func (request * http.Request ) bool
53
+ RequestDataType func (request * http.Request ) string
53
54
54
55
ToCallbackType func (callbackHeader contract.EventInterface , buf []byte ) (decryptMessage interface {}, err error )
55
56
@@ -75,6 +76,9 @@ func NewServerGuard(app *ApplicationInterface) *ServerGuard {
75
76
serverGuard .ShouldReturnRawResponse = func (request * http.Request ) bool {
76
77
return serverGuard .shouldReturnRawResponse (request )
77
78
}
79
+ serverGuard .RequestDataType = func (request * http.Request ) string {
80
+ return serverGuard .requestDataType (request )
81
+ }
78
82
79
83
serverGuard .OverrideGetToken ()
80
84
serverGuard .OverrideResolve ()
@@ -150,15 +154,18 @@ func (serverGuard *ServerGuard) GetEvent(request *http.Request) (callback *model
150
154
if request == nil {
151
155
return nil , nil , errors .New ("request is invalid" )
152
156
}
153
- var b []byte = []byte ("<xml></xml> " )
157
+ var b []byte = []byte ("" )
154
158
if request .Body != http .NoBody {
155
159
b , err = io .ReadAll (request .Body )
156
160
if err != nil || b == nil {
157
161
return nil , nil , err
158
162
}
159
163
}
160
164
161
- callback , err = serverGuard .ParseMessage (string (b ))
165
+ // 请求数据类型
166
+ rDataType := serverGuard .RequestDataType (request )
167
+
168
+ callback , err = serverGuard .ParseMessage (string (b ), rDataType )
162
169
if err != nil {
163
170
return nil , nil , err
164
171
}
@@ -167,12 +174,15 @@ func (serverGuard *ServerGuard) GetEvent(request *http.Request) (callback *model
167
174
callbackHeader , err = serverGuard .DecryptEvent (request , string (b ))
168
175
} else {
169
176
callbackHeader = & models.CallbackMessageHeader {}
170
- err = xml .Unmarshal (b , callbackHeader )
177
+ if rDataType == messages .DataTypeXML {
178
+ err = xml .Unmarshal (b , callbackHeader )
179
+ } else {
180
+ err = json .Unmarshal (b , callbackHeader )
181
+ }
171
182
callbackHeader .Content = b
172
183
}
173
184
174
185
return callback , callbackHeader , err
175
-
176
186
}
177
187
178
188
func (serverGuard * ServerGuard ) GetMessage (request * http.Request ) (callback * models.Callback , callbackHeader * models.CallbackMessageHeader , Decrypted interface {}, err error ) {
@@ -184,7 +194,7 @@ func (serverGuard *ServerGuard) GetMessage(request *http.Request) (callback *mod
184
194
}
185
195
}
186
196
187
- callback , err = serverGuard .ParseMessage (string (b ))
197
+ callback , err = serverGuard .ParseMessage (string (b ), serverGuard . RequestDataType ( request ) )
188
198
if err != nil {
189
199
return nil , nil , nil , err
190
200
}
@@ -196,11 +206,9 @@ func (serverGuard *ServerGuard) GetMessage(request *http.Request) (callback *mod
196
206
Text : callback .Text ,
197
207
ToUserName : callback .ToUserName ,
198
208
}
199
-
200
209
}
201
210
202
211
return callback , callbackHeader , Decrypted , err
203
-
204
212
}
205
213
206
214
func (serverGuard * ServerGuard ) ResolveEvent (request * http.Request , closure func (event contract.EventInterface ) interface {}) (rs * http.Response , err error ) {
@@ -401,32 +409,39 @@ func (serverGuard *ServerGuard) signature(params []string) string {
401
409
}
402
410
403
411
func (serverGuard * ServerGuard ) isSafeMode (request * http.Request ) bool {
404
-
405
412
query := request .URL .Query ()
406
413
407
414
return query .Get ("signature" ) != "" && "aes" == query .Get ("encrypt_type" )
408
-
409
415
}
410
416
411
- func (serverGuard * ServerGuard ) ParseMessage (content string ) (callback * models.Callback , err error ) {
417
+ func (serverGuard * ServerGuard ) requestDataType (request * http.Request ) string {
418
+ if strings .HasPrefix (request .Header .Get ("Content-Type" ), "text/xml" ) ||
419
+ strings .HasPrefix (request .Header .Get ("Content-Type" ), "application/xml" ) {
420
+ // xml 格式
421
+ return messages .DataTypeXML
422
+ } else {
423
+ // json 格式
424
+ return messages .DataTypeJSON
425
+ }
426
+ }
412
427
428
+ func (serverGuard * ServerGuard ) ParseMessage (content string , dataType string ) (callback * models.Callback , err error ) {
413
429
callback = & models.Callback {}
414
430
415
- if len (content ) > 0 {
416
- if content [0 :1 ] == "<" {
417
- err = xml .Unmarshal ([]byte (content ), callback )
418
- if err != nil {
419
- return nil , err
420
- }
421
- } else {
422
- // Handle JSON format.
423
- err = object .JsonDecode ([]byte (content ), callback )
424
- if err != nil {
425
- return nil , err
426
- }
431
+ if len (content ) <= 0 {
432
+ return nil , errors .New ("request body is empty" )
433
+ }
434
+
435
+ if dataType == messages .DataTypeXML {
436
+ err = xml .Unmarshal ([]byte (content ), callback )
437
+ if err != nil {
438
+ return nil , err
427
439
}
428
440
} else {
429
-
441
+ err = object .JsonDecode ([]byte (content ), callback )
442
+ if err != nil {
443
+ return nil , err
444
+ }
430
445
}
431
446
432
447
return callback , err
@@ -477,7 +492,11 @@ func (serverGuard *ServerGuard) DecryptEvent(request *http.Request, content stri
477
492
}
478
493
479
494
callbackHeader = & models.CallbackMessageHeader {}
480
- err = xml .Unmarshal (buf , callbackHeader )
495
+ if serverGuard .RequestDataType (request ) == messages .DataTypeXML {
496
+ err = xml .Unmarshal (buf , callbackHeader )
497
+ } else {
498
+ err = json .Unmarshal (buf , callbackHeader )
499
+ }
481
500
if err != nil {
482
501
return nil , err
483
502
}
@@ -503,7 +522,11 @@ func (serverGuard *ServerGuard) decryptMessage(request *http.Request, content st
503
522
}
504
523
505
524
callbackHeader = & models.CallbackMessageHeader {}
506
- err = xml .Unmarshal (buf , callbackHeader )
525
+ if serverGuard .RequestDataType (request ) == messages .DataTypeXML {
526
+ err = xml .Unmarshal (buf , callbackHeader )
527
+ } else {
528
+ err = json .Unmarshal (buf , callbackHeader )
529
+ }
507
530
if err != nil {
508
531
return nil , nil , err
509
532
}
0 commit comments