downloader.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright 2021 Tencent Inc. All rights reserved.
  2. package downloader
  3. import (
  4. "context"
  5. "crypto/rsa"
  6. "crypto/x509"
  7. "fmt"
  8. "sync"
  9. "git.nanodreamtech.com/sg/wechatpay-go/core"
  10. "git.nanodreamtech.com/sg/wechatpay-go/core/auth/signers"
  11. "git.nanodreamtech.com/sg/wechatpay-go/core/auth/validators"
  12. "git.nanodreamtech.com/sg/wechatpay-go/core/auth/verifiers"
  13. "git.nanodreamtech.com/sg/wechatpay-go/core/consts"
  14. "git.nanodreamtech.com/sg/wechatpay-go/utils"
  15. )
  16. // isSameCertificateMap Check if two CertificateMaps stores same certificates.
  17. // Normally, checking serial number set is enough.
  18. func isSameCertificateMap(l, r map[string]*x509.Certificate) bool {
  19. if l == nil && r == nil {
  20. return true
  21. }
  22. if len(l) != len(r) {
  23. return false
  24. }
  25. for serialNumber := range l {
  26. if _, ok := r[serialNumber]; !ok {
  27. return false
  28. }
  29. }
  30. return true
  31. }
  32. // CertificateDownloader 平台证书下载器,下载完成后可直接获取 x509.Certificate 对象或导出证书内容
  33. type CertificateDownloader struct {
  34. certContents map[string]string // 证书文本内容,用于导出
  35. certificates core.CertificateMap // 证书实例
  36. client *core.Client // 微信支付 API v3 Go SDK HTTPClient
  37. mchAPIv3Key string // 商户APIv3密钥
  38. lock sync.RWMutex
  39. }
  40. // Get 获取证书序列号对应的平台证书
  41. func (d *CertificateDownloader) Get(ctx context.Context, serialNumber string) (*x509.Certificate, bool) {
  42. d.lock.RLock()
  43. defer d.lock.RUnlock()
  44. return d.certificates.Get(ctx, serialNumber)
  45. }
  46. // GetAll 获取平台证书Map
  47. func (d *CertificateDownloader) GetAll(ctx context.Context) map[string]*x509.Certificate {
  48. d.lock.RLock()
  49. defer d.lock.RUnlock()
  50. return d.certificates.GetAll(ctx)
  51. }
  52. // GetNewestSerial 获取最新的平台证书的证书序列号
  53. func (d *CertificateDownloader) GetNewestSerial(ctx context.Context) string {
  54. d.lock.RLock()
  55. defer d.lock.RUnlock()
  56. return d.certificates.GetNewestSerial(ctx)
  57. }
  58. // Export 获取证书序列号对应的平台证书内容
  59. func (d *CertificateDownloader) Export(_ context.Context, serialNumber string) (string, bool) {
  60. d.lock.RLock()
  61. defer d.lock.RUnlock()
  62. content, ok := d.certContents[serialNumber]
  63. return content, ok
  64. }
  65. // ExportAll 获取平台证书内容Map
  66. func (d *CertificateDownloader) ExportAll(_ context.Context) map[string]string {
  67. d.lock.RLock()
  68. defer d.lock.RUnlock()
  69. ret := make(map[string]string)
  70. for serialNumber, content := range d.certContents {
  71. ret[serialNumber] = content
  72. }
  73. return ret
  74. }
  75. func (d *CertificateDownloader) decryptCertificate(
  76. _ context.Context, encryptCertificate *encryptCertificate,
  77. ) (string, error) {
  78. plaintext, err := utils.DecryptAES256GCM(
  79. d.mchAPIv3Key, *encryptCertificate.AssociatedData,
  80. *encryptCertificate.Nonce, *encryptCertificate.Ciphertext,
  81. )
  82. if err != nil {
  83. return "", fmt.Errorf("decrypt downloaded certificate failed: %v", err)
  84. }
  85. return plaintext, nil
  86. }
  87. func (d *CertificateDownloader) updateCertificates(
  88. ctx context.Context, certContents map[string]string, certificates map[string]*x509.Certificate,
  89. ) {
  90. d.lock.Lock()
  91. defer d.lock.Unlock()
  92. if isSameCertificateMap(d.certificates.GetAll(ctx), certificates) {
  93. return
  94. }
  95. d.certContents = certContents
  96. d.certificates.Reset(certificates)
  97. d.client = core.NewClientWithValidator(
  98. d.client,
  99. validators.NewWechatPayResponseValidator(verifiers.NewSHA256WithRSAVerifier(d)),
  100. )
  101. }
  102. func (d *CertificateDownloader) performDownloading(ctx context.Context) (*downloadCertificatesResponse, error) {
  103. result, err := d.client.Get(ctx, consts.WechatPayAPIServer+"/v3/global/certificates")
  104. if err != nil {
  105. return nil, err
  106. }
  107. resp := new(downloadCertificatesResponse)
  108. if err = core.UnMarshalResponse(result.Response, resp); err != nil {
  109. return nil, err
  110. }
  111. return resp, nil
  112. }
  113. // DownloadCertificates 立即下载平台证书列表
  114. func (d *CertificateDownloader) DownloadCertificates(ctx context.Context) error {
  115. resp, err := d.performDownloading(ctx)
  116. if err != nil {
  117. return err
  118. }
  119. rawCertContentMap := make(map[string]string)
  120. certificateMap := make(map[string]*x509.Certificate)
  121. for _, rawCertificate := range resp.Data {
  122. certContent, err := d.decryptCertificate(ctx, rawCertificate.EncryptCertificate)
  123. if err != nil {
  124. return err
  125. }
  126. certificate, err := utils.LoadCertificate(certContent)
  127. if err != nil {
  128. return fmt.Errorf("parse downlaoded certificate failed: %v, certcontent:%v", err, certContent)
  129. }
  130. serialNumber := *rawCertificate.SerialNo
  131. rawCertContentMap[serialNumber] = certContent
  132. certificateMap[serialNumber] = certificate
  133. }
  134. if len(certificateMap) == 0 {
  135. return fmt.Errorf("no certificate downloaded")
  136. }
  137. d.updateCertificates(ctx, rawCertContentMap, certificateMap)
  138. return nil
  139. }
  140. // NewCertificateDownloader 使用商户号/商户私钥等信息初始化商户的平台证书下载器 CertificateDownloader
  141. // 初始化完成后会立即发起一次下载,确保下载器被正确初始化。
  142. func NewCertificateDownloader(
  143. ctx context.Context, mchID string, privateKey *rsa.PrivateKey, certificateSerialNo string, mchAPIv3Key string,
  144. ) (*CertificateDownloader, error) {
  145. settings := core.DialSettings{
  146. Signer: &signers.SHA256WithRSASigner{
  147. MchID: mchID,
  148. PrivateKey: privateKey,
  149. CertificateSerialNo: certificateSerialNo,
  150. },
  151. Validator: &validators.NullValidator{},
  152. }
  153. client, err := core.NewClientWithDialSettings(ctx, &settings)
  154. if err != nil {
  155. return nil, fmt.Errorf("create downloader failed, create client err:%v", err)
  156. }
  157. return NewCertificateDownloaderWithClient(ctx, client, mchAPIv3Key)
  158. }
  159. // NewCertificateDownloaderWithClient 使用 core.Client 初始化商户的平台证书下载器 CertificateDownloader
  160. // 初始化完成后会立即发起一次下载,确保下载器被正确初始化。
  161. func NewCertificateDownloaderWithClient(
  162. ctx context.Context, client *core.Client, mchAPIv3Key string,
  163. ) (*CertificateDownloader, error) {
  164. downloader := CertificateDownloader{
  165. client: client,
  166. mchAPIv3Key: mchAPIv3Key,
  167. }
  168. if err := downloader.DownloadCertificates(ctx); err != nil {
  169. return nil, err
  170. }
  171. return &downloader, nil
  172. }