wechat_pay_cipher.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. // Copyright 2021 Tencent Inc. All rights reserved.
  2. package ciphers
  3. import (
  4. "context"
  5. "fmt"
  6. "reflect"
  7. "git.nanodreamtech.com/sg/wechatpay-go/core/cipher"
  8. )
  9. type cipherType string
  10. const (
  11. cipherTypeEncrypt cipherType = "encrypt"
  12. cipherTypeDecrypt cipherType = "decrypt"
  13. )
  14. const (
  15. fieldTagEncryption = "encryption"
  16. encryptionTypeAPIV3 = "EM_APIV3"
  17. )
  18. // fieldCipherFuncType 用于对特定类型字段进行加/解密的方法类型
  19. type fieldCipherFuncType func(*WechatPayCipher, context.Context, cipherType, reflect.StructField, reflect.Value) error
  20. // fieldCipherMap 对不同类型字段进行加/解密的方法字典
  21. var fieldCipherMap map[reflect.Kind]fieldCipherFuncType
  22. func init() {
  23. // 初始化加/解密方法字典,使用 init 初始化而不是直接在声明时初始化的原因是为了避免初始化循环依赖
  24. fieldCipherMap = map[reflect.Kind]fieldCipherFuncType{
  25. reflect.Struct: (*WechatPayCipher).cipherStructField,
  26. reflect.Array: (*WechatPayCipher).cipherArrayField,
  27. reflect.Slice: (*WechatPayCipher).cipherArrayField,
  28. reflect.String: (*WechatPayCipher).cipherStringField,
  29. }
  30. }
  31. // WechatPayCipher 提供微信支付敏感信息加解密功能
  32. //
  33. // 为了保证通信过程中敏感信息字段(如用户的住址、银行卡号、手机号码等)的机密性,微信支付API v3要求:
  34. // 1. 商户对上送的敏感信息字段进行加密
  35. // 2. 微信支付对下行的敏感信息字段进行加密
  36. //
  37. // 详见:https://wechatpay-api.gitbook.io/wechatpay-api-v3/qian-ming-zhi-nan-1/min-gan-xin-xi-jia-mi
  38. type WechatPayCipher struct {
  39. encryptor cipher.Encryptor
  40. decryptor cipher.Decryptor
  41. }
  42. // Encrypt 对结构中的敏感字段进行加密
  43. func (c *WechatPayCipher) Encrypt(ctx context.Context, in interface{}) (string, error) {
  44. serial, err := c.encryptor.SelectCertificate(ctx)
  45. if err != nil {
  46. return "", err
  47. }
  48. ctx = setEncryptSerial(ctx, serial)
  49. if v, ok := in.(reflect.Value); ok {
  50. err = c.cipher(ctx, cipherTypeEncrypt, v)
  51. } else {
  52. err = c.cipher(ctx, cipherTypeEncrypt, reflect.ValueOf(in))
  53. }
  54. if err != nil {
  55. return "", fmt.Errorf("encrypt struct failed: %w", err)
  56. }
  57. return serial, nil
  58. }
  59. // Decrypt 对结构中的敏感字段进行解密
  60. func (c *WechatPayCipher) Decrypt(ctx context.Context, in interface{}) error {
  61. var err error
  62. if v, ok := in.(reflect.Value); ok {
  63. err = c.cipher(ctx, cipherTypeDecrypt, v)
  64. } else {
  65. err = c.cipher(ctx, cipherTypeDecrypt, reflect.ValueOf(in))
  66. }
  67. if err != nil {
  68. return fmt.Errorf("decrypt struct failed: %w", err)
  69. }
  70. return nil
  71. }
  72. // cipher 执行加/解密的入口函数
  73. func (c *WechatPayCipher) cipher(ctx context.Context, ty cipherType, v reflect.Value) error {
  74. var isNil bool
  75. if v, isNil = derefPtrValue(v); isNil {
  76. // No cipher required for nil ptr
  77. return nil
  78. }
  79. if !v.CanSet() {
  80. return fmt.Errorf("in-place cipher requires settable input, ptr for example")
  81. }
  82. if v.Type().Kind() != reflect.Struct {
  83. return fmt.Errorf("only struct can be ciphered")
  84. }
  85. return c.cipherStruct(ctx, ty, v)
  86. }
  87. // cipherStruct 递归进行Struct的加/解密操作
  88. func (c *WechatPayCipher) cipherStruct(ctx context.Context, ty cipherType, v reflect.Value) error {
  89. var t = v.Type()
  90. for i := 0; i < t.NumField(); i++ {
  91. field := t.Field(i)
  92. fieldValue := v.Field(i)
  93. if err := c.cipherField(ctx, ty, field, fieldValue); err != nil {
  94. return err
  95. }
  96. }
  97. return nil
  98. }
  99. // derefPtrValue 将 Ptr 类型的 Value 解引用,直到获得非指针内容。
  100. //
  101. // 如果输入的 Value 不是 Ptr,则返回其本身
  102. // 如果输入的 Value 最终指向 Nil,则返回最终指向 Nil 的 Value 对象,且 isNil 为 true
  103. func derefPtrValue(inValue reflect.Value) (outValue reflect.Value, isNil bool) {
  104. v := inValue
  105. for v.Type().Kind() == reflect.Ptr {
  106. if v.IsNil() {
  107. return v, true
  108. }
  109. v = v.Elem()
  110. }
  111. return v, false
  112. }
  113. // cipherField 对字段进行加/解密
  114. func (c *WechatPayCipher) cipherField(
  115. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  116. ) error {
  117. if !fieldValue.CanInterface() {
  118. // Skip Unexported Field
  119. return nil
  120. }
  121. var isNil bool
  122. if fieldValue, isNil = derefPtrValue(fieldValue); isNil {
  123. // Skip Field with no data
  124. return nil
  125. }
  126. if fieldCipherFunc, ok := fieldCipherMap[fieldValue.Type().Kind()]; ok {
  127. return fieldCipherFunc(c, ctx, ty, field, fieldValue)
  128. }
  129. return nil
  130. }
  131. // cipherStructField 对Struct类型的字段进行加/解密
  132. func (c *WechatPayCipher) cipherStructField(
  133. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  134. ) error {
  135. _ = field
  136. return c.cipherStruct(ctx, ty, fieldValue)
  137. }
  138. // cipherArrayField 对Array/Slice类型的字段进行加/解密
  139. func (c *WechatPayCipher) cipherArrayField(
  140. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  141. ) error {
  142. elemType := fieldValue.Type().Elem()
  143. if _, ok := fieldCipherMap[elemType.Kind()]; !ok {
  144. // Field Element Type Requires no encryption, skip
  145. return nil
  146. }
  147. for j := 0; j < fieldValue.Len(); j++ {
  148. elemValue := fieldValue.Index(j)
  149. if err := c.cipherField(ctx, ty, field, elemValue); err != nil {
  150. return err
  151. }
  152. }
  153. return nil
  154. }
  155. // cipherStringField 对String类型的字段进行加/解密
  156. func (c *WechatPayCipher) cipherStringField(
  157. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  158. ) error {
  159. if field.Tag.Get(fieldTagEncryption) != encryptionTypeAPIV3 {
  160. return nil
  161. }
  162. var cipherText string
  163. var err error
  164. switch ty {
  165. case cipherTypeEncrypt:
  166. serial, ok := getEncryptSerial(ctx)
  167. if !ok {
  168. // 前置逻辑已经设置了 EncryptSerial,这里正常来讲不会进入
  169. return fmt.Errorf("`%s` not provided in ctx(should not happen)", contextKeyEncryptSerial)
  170. }
  171. cipherText, err = c.encryptor.Encrypt(ctx, serial, fieldValue.Interface().(string))
  172. case cipherTypeDecrypt:
  173. cipherText, err = c.decryptor.Decrypt(ctx, fieldValue.Interface().(string))
  174. default:
  175. // 前置逻辑不会设置其他类型,这里正常来讲不会进入
  176. return fmt.Errorf("invalid cipher type:%v(should not happen)", ty)
  177. }
  178. if err != nil {
  179. return err
  180. }
  181. fieldValue.SetString(cipherText)
  182. return nil
  183. }
  184. // NewWechatPayCipher 使用 cipher.Encryptor + cipher.Decryptor 构建一个 WechatPayCipher
  185. func NewWechatPayCipher(encryptor cipher.Encryptor, decryptor cipher.Decryptor) *WechatPayCipher {
  186. return &WechatPayCipher{encryptor: encryptor, decryptor: decryptor}
  187. }