downloader_mgr.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. // Copyright 2021 Tencent Inc. All rights reserved.
  2. package downloader
  3. import (
  4. "context"
  5. "crypto/rsa"
  6. "crypto/x509"
  7. "sync"
  8. "time"
  9. "git.nanodreamtech.com/sg/wechatpay-go/core"
  10. "git.nanodreamtech.com/sg/wechatpay-go/utils/task"
  11. )
  12. const (
  13. // DefaultDownloadInterval 默认微信支付平台证书更新间隔
  14. DefaultDownloadInterval = 24 * time.Hour
  15. )
  16. type pseudoCertificateDownloader struct {
  17. mgr *CertificateDownloaderMgr
  18. mchID string
  19. }
  20. // GetAll 获取平台证书Map
  21. func (d *pseudoCertificateDownloader) GetAll(ctx context.Context) map[string]*x509.Certificate {
  22. return d.mgr.GetCertificateMap(ctx, d.mchID)
  23. }
  24. // Get 获取证书序列号对应的平台证书
  25. func (d *pseudoCertificateDownloader) Get(ctx context.Context, serialNumber string) (*x509.Certificate, bool) {
  26. return d.mgr.GetCertificate(ctx, d.mchID, serialNumber)
  27. }
  28. // GetNewestSerial 获取最新的平台证书的证书序列号
  29. func (d *pseudoCertificateDownloader) GetNewestSerial(ctx context.Context) string {
  30. return d.mgr.GetNewestCertificateSerial(ctx, d.mchID)
  31. }
  32. // ExportAll 获取平台证书内容Map
  33. func (d *pseudoCertificateDownloader) ExportAll(ctx context.Context) map[string]string {
  34. return d.mgr.ExportCertificateMap(ctx, d.mchID)
  35. }
  36. // Export 获取证书序列号对应的平台证书内容
  37. func (d *pseudoCertificateDownloader) Export(ctx context.Context, serialNumber string) (string, bool) {
  38. return d.mgr.ExportCertificate(ctx, d.mchID, serialNumber)
  39. }
  40. // CertificateDownloaderMgr 证书下载器管理器
  41. // 可挂载证书下载器 CertificateDownloader,会定时调用 CertificateDownloader 下载最新的证书
  42. //
  43. // CertificateDownloaderMgr 不会被 GoGC 自动回收,不再使用时应调用 Stop 方法,防止发生资源泄漏
  44. type CertificateDownloaderMgr struct {
  45. ctx context.Context
  46. task *task.RepeatedTask
  47. downloaderMap map[string]*CertificateDownloader
  48. lock sync.RWMutex
  49. }
  50. // Stop 停止 CertificateDownloaderMgr 的自动下载 Goroutine
  51. // 当且仅当不再需要当前管理器自动下载后调用
  52. // 一旦调用成功,当前管理器无法再次启动
  53. func (mgr *CertificateDownloaderMgr) Stop() {
  54. mgr.lock.Lock()
  55. defer mgr.lock.Unlock()
  56. mgr.task.Stop()
  57. }
  58. // GetCertificate 获取商户的某个平台证书
  59. func (mgr *CertificateDownloaderMgr) GetCertificate(ctx context.Context, mchID, serialNumber string) (
  60. *x509.Certificate, bool,
  61. ) {
  62. mgr.lock.RLock()
  63. downloader, ok := mgr.downloaderMap[mchID]
  64. mgr.lock.RUnlock()
  65. if !ok {
  66. return nil, false
  67. }
  68. return downloader.Get(ctx, serialNumber)
  69. }
  70. // GetCertificateMap 获取商户的平台证书Map
  71. func (mgr *CertificateDownloaderMgr) GetCertificateMap(ctx context.Context, mchID string) map[string]*x509.Certificate {
  72. mgr.lock.RLock()
  73. downloader, ok := mgr.downloaderMap[mchID]
  74. mgr.lock.RUnlock()
  75. if !ok {
  76. return nil
  77. }
  78. return downloader.GetAll(ctx)
  79. }
  80. // GetNewestCertificateSerial 获取商户的最新的平台证书序列号
  81. func (mgr *CertificateDownloaderMgr) GetNewestCertificateSerial(ctx context.Context, mchID string) string {
  82. mgr.lock.RLock()
  83. downloader, ok := mgr.downloaderMap[mchID]
  84. mgr.lock.RUnlock()
  85. if !ok {
  86. return ""
  87. }
  88. return downloader.GetNewestSerial(ctx)
  89. }
  90. // ExportCertificate 获取商户的某个平台证书内容
  91. func (mgr *CertificateDownloaderMgr) ExportCertificate(ctx context.Context, mchID, serialNumber string) (string, bool) {
  92. mgr.lock.RLock()
  93. downloader, ok := mgr.downloaderMap[mchID]
  94. mgr.lock.RUnlock()
  95. if !ok {
  96. return "", false
  97. }
  98. return downloader.Export(ctx, serialNumber)
  99. }
  100. // ExportCertificateMap 导出商户的平台证书内容Map
  101. func (mgr *CertificateDownloaderMgr) ExportCertificateMap(ctx context.Context, mchID string) map[string]string {
  102. mgr.lock.RLock()
  103. downloader, ok := mgr.downloaderMap[mchID]
  104. mgr.lock.RUnlock()
  105. if !ok {
  106. return nil
  107. }
  108. return downloader.ExportAll(ctx)
  109. }
  110. // GetCertificateVisitor 获取某个商户的平台证书访问器
  111. func (mgr *CertificateDownloaderMgr) GetCertificateVisitor(mchID string) core.CertificateVisitor {
  112. return &pseudoCertificateDownloader{mgr: mgr, mchID: mchID}
  113. }
  114. func (mgr *CertificateDownloaderMgr) getTickHandler() func(time.Time) {
  115. return func(time.Time) {
  116. mgr.DownloadCertificates(mgr.ctx)
  117. }
  118. }
  119. // DownloadCertificates 让所有已注册下载器均进行一次下载
  120. func (mgr *CertificateDownloaderMgr) DownloadCertificates(ctx context.Context) {
  121. tmpDownloaderMap := make(map[string]*CertificateDownloader)
  122. mgr.lock.RLock()
  123. for key, downloader := range mgr.downloaderMap {
  124. tmpDownloaderMap[key] = downloader
  125. }
  126. mgr.lock.RUnlock()
  127. for _, downloader := range tmpDownloaderMap {
  128. _ = downloader.DownloadCertificates(ctx)
  129. }
  130. }
  131. // RegisterDownloaderWithPrivateKey 向 Mgr 注册商户的平台证书下载器
  132. func (mgr *CertificateDownloaderMgr) RegisterDownloaderWithPrivateKey(
  133. ctx context.Context, privateKey *rsa.PrivateKey,
  134. certificateSerialNo string, mchID string, mchAPIv3Key string,
  135. ) error {
  136. downloader, err := NewCertificateDownloader(ctx, mchID, privateKey, certificateSerialNo, mchAPIv3Key)
  137. if err != nil {
  138. return err
  139. }
  140. mgr.lock.Lock()
  141. defer mgr.lock.Unlock()
  142. mgr.downloaderMap[mchID] = downloader
  143. return nil
  144. }
  145. // RegisterDownloaderWithClient 向 Mgr 注册商户的平台证书下载器
  146. func (mgr *CertificateDownloaderMgr) RegisterDownloaderWithClient(
  147. ctx context.Context, client *core.Client, mchID string, mchAPIv3Key string,
  148. ) error {
  149. downloader, err := NewCertificateDownloaderWithClient(ctx, client, mchAPIv3Key)
  150. if err != nil {
  151. return err
  152. }
  153. mgr.lock.Lock()
  154. defer mgr.lock.Unlock()
  155. mgr.downloaderMap[mchID] = downloader
  156. return nil
  157. }
  158. // RemoveDownloader 移除商户的平台证书下载器
  159. // 移除后从 GetCertificateVisitor 接口获得的对应商户的 CertificateVisitor 将会失效,
  160. // 请确认不再需要该商户的证书后再行移除,如果下载器存在,本接口将会返回该下载器。
  161. func (mgr *CertificateDownloaderMgr) RemoveDownloader(_ context.Context, mchID string) *CertificateDownloader {
  162. mgr.lock.Lock()
  163. defer mgr.lock.Unlock()
  164. downloader, ok := mgr.downloaderMap[mchID]
  165. if !ok {
  166. return nil
  167. }
  168. delete(mgr.downloaderMap, mchID)
  169. return downloader
  170. }
  171. // HasDownloader 检查是否已经注册过 mchID 这个商户的下载器
  172. func (mgr *CertificateDownloaderMgr) HasDownloader(_ context.Context, mchID string) bool {
  173. mgr.lock.RLock()
  174. defer mgr.lock.RUnlock()
  175. _, ok := mgr.downloaderMap[mchID]
  176. return ok
  177. }
  178. // NewCertificateDownloaderMgr 以默认间隔 DefaultDownloadInterval 创建证书下载管理器
  179. // 该管理器将以 DefaultDownloadInterval 的间隔定期调度所有 Downloader 进行证书下载。
  180. // 证书管理器一旦创建即启动,使用完毕请调用 Stop() 防止发生资源泄漏
  181. func NewCertificateDownloaderMgr(ctx context.Context) *CertificateDownloaderMgr {
  182. return NewCertificateDownloaderMgrWithInterval(ctx, DefaultDownloadInterval)
  183. }
  184. // NewCertificateDownloaderMgrWithInterval 创建一个空证书下载管理器(自定义更新间隔)
  185. //
  186. // 更新间隔最大不建议超过 2 天,以免错过平台证书平滑切换窗口;
  187. // 同时亦不建议小于 1 小时,以避免过多请求导致浪费
  188. func NewCertificateDownloaderMgrWithInterval(
  189. ctx context.Context, downloadInterval time.Duration,
  190. ) *CertificateDownloaderMgr {
  191. if downloadInterval <= 0 {
  192. downloadInterval = DefaultDownloadInterval
  193. }
  194. downloader := CertificateDownloaderMgr{
  195. ctx: ctx,
  196. downloaderMap: make(map[string]*CertificateDownloader),
  197. }
  198. downloader.task = task.NewRepeatedTask(downloadInterval, downloader.getTickHandler())
  199. downloader.task.Start()
  200. return &downloader
  201. }