123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- package downloader
- import (
- "context"
- "crypto/rsa"
- "crypto/x509"
- "sync"
- "time"
- "github.com/wechatpay-apiv3/wechatpay-go/core"
- "github.com/wechatpay-apiv3/wechatpay-go/utils/task"
- )
// DefaultDownloadInterval is the download interval (24 hours) used by
// NewCertificateDownloaderMgr, and by NewCertificateDownloaderMgrWithInterval
// when it is given a non-positive interval.
const DefaultDownloadInterval = 24 * time.Hour
- type pseudoCertificateDownloader struct {
- mgr *CertificateDownloaderMgr
- mchID string
- }
- func (d *pseudoCertificateDownloader) GetAll(ctx context.Context) map[string]*x509.Certificate {
- return d.mgr.GetCertificateMap(ctx, d.mchID)
- }
- func (d *pseudoCertificateDownloader) Get(ctx context.Context, serialNumber string) (*x509.Certificate, bool) {
- return d.mgr.GetCertificate(ctx, d.mchID, serialNumber)
- }
- func (d *pseudoCertificateDownloader) GetNewestSerial(ctx context.Context) string {
- return d.mgr.GetNewestCertificateSerial(ctx, d.mchID)
- }
- func (d *pseudoCertificateDownloader) ExportAll(ctx context.Context) map[string]string {
- return d.mgr.ExportCertificateMap(ctx, d.mchID)
- }
- func (d *pseudoCertificateDownloader) Export(ctx context.Context, serialNumber string) (string, bool) {
- return d.mgr.ExportCertificate(ctx, d.mchID, serialNumber)
- }
- type CertificateDownloaderMgr struct {
- ctx context.Context
- task *task.RepeatedTask
- downloaderMap map[string]*CertificateDownloader
- lock sync.RWMutex
- }
- func (mgr *CertificateDownloaderMgr) Stop() {
- mgr.lock.Lock()
- defer mgr.lock.Unlock()
- mgr.task.Stop()
- }
- func (mgr *CertificateDownloaderMgr) GetCertificate(ctx context.Context, mchID, serialNumber string) (
- *x509.Certificate, bool,
- ) {
- mgr.lock.RLock()
- downloader, ok := mgr.downloaderMap[mchID]
- mgr.lock.RUnlock()
- if !ok {
- return nil, false
- }
- return downloader.Get(ctx, serialNumber)
- }
- func (mgr *CertificateDownloaderMgr) GetCertificateMap(ctx context.Context, mchID string) map[string]*x509.Certificate {
- mgr.lock.RLock()
- downloader, ok := mgr.downloaderMap[mchID]
- mgr.lock.RUnlock()
- if !ok {
- return nil
- }
- return downloader.GetAll(ctx)
- }
- func (mgr *CertificateDownloaderMgr) GetNewestCertificateSerial(ctx context.Context, mchID string) string {
- mgr.lock.RLock()
- downloader, ok := mgr.downloaderMap[mchID]
- mgr.lock.RUnlock()
- if !ok {
- return ""
- }
- return downloader.GetNewestSerial(ctx)
- }
- func (mgr *CertificateDownloaderMgr) ExportCertificate(ctx context.Context, mchID, serialNumber string) (string, bool) {
- mgr.lock.RLock()
- downloader, ok := mgr.downloaderMap[mchID]
- mgr.lock.RUnlock()
- if !ok {
- return "", false
- }
- return downloader.Export(ctx, serialNumber)
- }
- func (mgr *CertificateDownloaderMgr) ExportCertificateMap(ctx context.Context, mchID string) map[string]string {
- mgr.lock.RLock()
- downloader, ok := mgr.downloaderMap[mchID]
- mgr.lock.RUnlock()
- if !ok {
- return nil
- }
- return downloader.ExportAll(ctx)
- }
- func (mgr *CertificateDownloaderMgr) GetCertificateVisitor(mchID string) core.CertificateVisitor {
- return &pseudoCertificateDownloader{mgr: mgr, mchID: mchID}
- }
- func (mgr *CertificateDownloaderMgr) getTickHandler() func(time.Time) {
- return func(time.Time) {
- mgr.DownloadCertificates(mgr.ctx)
- }
- }
- func (mgr *CertificateDownloaderMgr) DownloadCertificates(ctx context.Context) {
- tmpDownloaderMap := make(map[string]*CertificateDownloader)
- mgr.lock.RLock()
- for key, downloader := range mgr.downloaderMap {
- tmpDownloaderMap[key] = downloader
- }
- mgr.lock.RUnlock()
- for _, downloader := range tmpDownloaderMap {
- _ = downloader.DownloadCertificates(ctx)
- }
- }
- func (mgr *CertificateDownloaderMgr) RegisterDownloaderWithPrivateKey(
- ctx context.Context, privateKey *rsa.PrivateKey,
- certificateSerialNo string, mchID string, mchAPIv3Key string,
- ) error {
- downloader, err := NewCertificateDownloader(ctx, mchID, privateKey, certificateSerialNo, mchAPIv3Key)
- if err != nil {
- return err
- }
- mgr.lock.Lock()
- defer mgr.lock.Unlock()
- mgr.downloaderMap[mchID] = downloader
- return nil
- }
- func (mgr *CertificateDownloaderMgr) RegisterDownloaderWithClient(
- ctx context.Context, client *core.Client, mchID string, mchAPIv3Key string,
- ) error {
- downloader, err := NewCertificateDownloaderWithClient(ctx, client, mchAPIv3Key)
- if err != nil {
- return err
- }
- mgr.lock.Lock()
- defer mgr.lock.Unlock()
- mgr.downloaderMap[mchID] = downloader
- return nil
- }
- func (mgr *CertificateDownloaderMgr) RemoveDownloader(_ context.Context, mchID string) *CertificateDownloader {
- mgr.lock.Lock()
- defer mgr.lock.Unlock()
- downloader, ok := mgr.downloaderMap[mchID]
- if !ok {
- return nil
- }
- delete(mgr.downloaderMap, mchID)
- return downloader
- }
- func (mgr *CertificateDownloaderMgr) HasDownloader(_ context.Context, mchID string) bool {
- mgr.lock.RLock()
- defer mgr.lock.RUnlock()
- _, ok := mgr.downloaderMap[mchID]
- return ok
- }
- func NewCertificateDownloaderMgr(ctx context.Context) *CertificateDownloaderMgr {
- return NewCertificateDownloaderMgrWithInterval(ctx, DefaultDownloadInterval)
- }
- func NewCertificateDownloaderMgrWithInterval(
- ctx context.Context, downloadInterval time.Duration,
- ) *CertificateDownloaderMgr {
- if downloadInterval <= 0 {
- downloadInterval = DefaultDownloadInterval
- }
- downloader := CertificateDownloaderMgr{
- ctx: ctx,
- downloaderMap: make(map[string]*CertificateDownloader),
- }
- downloader.task = task.NewRepeatedTask(downloadInterval, downloader.getTickHandler())
- downloader.task.Start()
- return &downloader
- }
|