wechat_pay_cipher.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. // Copyright 2021 Tencent Inc. All rights reserved.
  2. package ciphers
  3. import (
  4. "context"
  5. "fmt"
  6. "reflect"
  7. "github.com/wechatpay-apiv3/wechatpay-go/core/cipher"
  8. )
  // cipherType selects the direction of a cipher pass over a struct.
type cipherType string

// The two directions a cipher pass can run in.
const (
	cipherTypeEncrypt = cipherType("encrypt")
	cipherTypeDecrypt = cipherType("decrypt")
)
  // fieldTagEncryption is the struct tag key that marks a field as sensitive.
const fieldTagEncryption = "encryption"

// encryptionTypeAPIV3 is the only tag value that triggers ciphering: string
// fields are processed only when tagged `encryption:"EM_APIV3"`.
const encryptionTypeAPIV3 = "EM_APIV3"
  18. // fieldCipherFuncType 用于对特定类型字段进行加/解密的方法类型
  19. type fieldCipherFuncType func(*WechatPayCipher, context.Context, cipherType, reflect.StructField, reflect.Value) error
  20. // fieldCipherMap 对不同类型字段进行加/解密的方法字典
  21. var fieldCipherMap map[reflect.Kind]fieldCipherFuncType
  22. func init() {
  23. // 初始化加/解密方法字典,使用 init 初始化而不是直接在声明时初始化的原因是为了避免初始化循环依赖
  24. fieldCipherMap = map[reflect.Kind]fieldCipherFuncType{
  25. reflect.Struct: (*WechatPayCipher).cipherStructField,
  26. reflect.Array: (*WechatPayCipher).cipherArrayField,
  27. reflect.Slice: (*WechatPayCipher).cipherArrayField,
  28. reflect.String: (*WechatPayCipher).cipherStringField,
  29. }
  30. }
  31. // WechatPayCipher 提供微信支付敏感信息加解密功能
  32. //
  33. // 为了保证通信过程中敏感信息字段(如用户的住址、银行卡号、手机号码等)的机密性,微信支付API v3要求:
  34. // 1. 商户对上送的敏感信息字段进行加密
  35. // 2. 微信支付对下行的敏感信息字段进行加密
  36. // 详见:https://wechatpay-api.gitbook.io/wechatpay-api-v3/qian-ming-zhi-nan-1/min-gan-xin-xi-jia-mi
  37. type WechatPayCipher struct {
  38. encryptor cipher.Encryptor
  39. decryptor cipher.Decryptor
  40. }
  41. // Encrypt 对结构中的敏感字段进行加密
  42. func (c *WechatPayCipher) Encrypt(ctx context.Context, in interface{}) (string, error) {
  43. serial, err := c.encryptor.SelectCertificate(ctx)
  44. if err != nil {
  45. return "", err
  46. }
  47. ctx = setEncryptSerial(ctx, serial)
  48. if v, ok := in.(reflect.Value); ok {
  49. err = c.cipher(ctx, cipherTypeEncrypt, v)
  50. } else {
  51. err = c.cipher(ctx, cipherTypeEncrypt, reflect.ValueOf(in))
  52. }
  53. if err != nil {
  54. return "", fmt.Errorf("encrypt struct failed: %w", err)
  55. }
  56. return serial, nil
  57. }
  58. // Decrypt 对结构中的敏感字段进行解密
  59. func (c *WechatPayCipher) Decrypt(ctx context.Context, in interface{}) error {
  60. var err error
  61. if v, ok := in.(reflect.Value); ok {
  62. err = c.cipher(ctx, cipherTypeDecrypt, v)
  63. } else {
  64. err = c.cipher(ctx, cipherTypeDecrypt, reflect.ValueOf(in))
  65. }
  66. if err != nil {
  67. return fmt.Errorf("decrypt struct failed: %w", err)
  68. }
  69. return nil
  70. }
  71. // cipher 执行加/解密的入口函数
  72. func (c *WechatPayCipher) cipher(ctx context.Context, ty cipherType, v reflect.Value) error {
  73. var isNil bool
  74. if v, isNil = derefPtrValue(v); isNil {
  75. // No cipher required for nil ptr
  76. return nil
  77. }
  78. if !v.CanSet() {
  79. return fmt.Errorf("in-place cipher requires settable input, ptr for example")
  80. }
  81. if v.Type().Kind() != reflect.Struct {
  82. return fmt.Errorf("only struct can be ciphered")
  83. }
  84. return c.cipherStruct(ctx, ty, v)
  85. }
  86. // cipherStruct 递归进行Struct的加/解密操作
  87. func (c *WechatPayCipher) cipherStruct(ctx context.Context, ty cipherType, v reflect.Value) error {
  88. var t = v.Type()
  89. for i := 0; i < t.NumField(); i++ {
  90. field := t.Field(i)
  91. fieldValue := v.Field(i)
  92. if err := c.cipherField(ctx, ty, field, fieldValue); err != nil {
  93. return err
  94. }
  95. }
  96. return nil
  97. }
  // derefPtrValue follows pointer indirections on inValue until it reaches a
// non-pointer Value.
//
// If inValue is not a pointer it is returned unchanged with isNil == false.
// If the chain ends in a nil pointer, that nil pointer Value is returned with
// isNil == true. An invalid (zero) reflect.Value, such as reflect.ValueOf(nil),
// carries no data either, so it is also reported with isNil == true rather
// than panicking.
func derefPtrValue(inValue reflect.Value) (outValue reflect.Value, isNil bool) {
	v := inValue
	if !v.IsValid() {
		return v, true
	}
	// v.Kind() is safe on any Value, unlike v.Type().Kind().
	for v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return v, true
		}
		v = v.Elem()
	}
	return v, false
}
  112. // cipherField 对字段进行加/解密
  113. func (c *WechatPayCipher) cipherField(
  114. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  115. ) error {
  116. if !fieldValue.CanInterface() {
  117. // Skip Unexported Field
  118. return nil
  119. }
  120. var isNil bool
  121. if fieldValue, isNil = derefPtrValue(fieldValue); isNil {
  122. // Skip Field with no data
  123. return nil
  124. }
  125. if fieldCipherFunc, ok := fieldCipherMap[fieldValue.Type().Kind()]; ok {
  126. return fieldCipherFunc(c, ctx, ty, field, fieldValue)
  127. }
  128. return nil
  129. }
  130. // cipherStructField 对Struct类型的字段进行加/解密
  131. func (c *WechatPayCipher) cipherStructField(
  132. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  133. ) error {
  134. _ = field
  135. return c.cipherStruct(ctx, ty, fieldValue)
  136. }
  137. // cipherArrayField 对Array/Slice类型的字段进行加/解密
  138. func (c *WechatPayCipher) cipherArrayField(
  139. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  140. ) error {
  141. elemType := fieldValue.Type().Elem()
  142. if _, ok := fieldCipherMap[elemType.Kind()]; !ok {
  143. // Field Element Type Requires no encryption, skip
  144. return nil
  145. }
  146. for j := 0; j < fieldValue.Len(); j++ {
  147. elemValue := fieldValue.Index(j)
  148. if err := c.cipherField(ctx, ty, field, elemValue); err != nil {
  149. return err
  150. }
  151. }
  152. return nil
  153. }
  154. // cipherStringField 对String类型的字段进行加/解密
  155. func (c *WechatPayCipher) cipherStringField(
  156. ctx context.Context, ty cipherType, field reflect.StructField, fieldValue reflect.Value,
  157. ) error {
  158. if field.Tag.Get(fieldTagEncryption) != encryptionTypeAPIV3 {
  159. return nil
  160. }
  161. var cipherText string
  162. var err error
  163. switch ty {
  164. case cipherTypeEncrypt:
  165. serial, ok := getEncryptSerial(ctx)
  166. if !ok {
  167. // 前置逻辑已经设置了 EncryptSerial,这里正常来讲不会进入
  168. return fmt.Errorf("`%s` not provided in ctx(should not happen)", contextKeyEncryptSerial)
  169. }
  170. cipherText, err = c.encryptor.Encrypt(ctx, serial, fieldValue.Interface().(string))
  171. case cipherTypeDecrypt:
  172. cipherText, err = c.decryptor.Decrypt(ctx, fieldValue.Interface().(string))
  173. default:
  174. // 前置逻辑不会设置其他类型,这里正常来讲不会进入
  175. return fmt.Errorf("invalid cipher type:%v(should not happen)", ty)
  176. }
  177. if err != nil {
  178. return err
  179. }
  180. fieldValue.SetString(cipherText)
  181. return nil
  182. }
  183. // NewWechatPayCipher 使用 cipher.Encryptor + cipher.Decryptor 构建一个 WechatPayCipher
  184. func NewWechatPayCipher(encryptor cipher.Encryptor, decryptor cipher.Decryptor) *WechatPayCipher {
  185. return &WechatPayCipher{encryptor: encryptor, decryptor: decryptor}
  186. }