Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Xavier-Lam committed Jul 25, 2023
1 parent 3ac02d9 commit 0ad3620
Show file tree
Hide file tree
Showing 16 changed files with 329 additions and 356 deletions.
8 changes: 5 additions & 3 deletions errors/errors.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package errors

import "github.com/Xavier-Lam/go-wechat/internal/client"
import (
"github.com/Xavier-Lam/go-wechat/internal/auth"
"github.com/Xavier-Lam/go-wechat/internal/client"
)

// Exported errors
type (
WeChatApiError = client.WeChatApiError
)

var (
// client
ErrCacheNotSet = client.ErrCacheNotSet
ErrCacheNotSet = auth.ErrCacheNotSet
ErrInvalidResponse = client.ErrInvalidResponse
)
126 changes: 124 additions & 2 deletions internal/auth/accesstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@ package auth

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"time"

"github.com/Xavier-Lam/go-wechat/caches"
)

const DefaultTokenExpiresIn = 7200
const DefaultAccessTokenExpiresIn = 7200

// AccessTokenClient is an client to request the newest access token
type AccessTokenClient interface {
PrepareRequest(auth Auth) (*http.Request, error)
SendRequest(auth Auth, req *http.Request) (*http.Response, error)
HandleResponse(auth Auth, resp *http.Response, req *http.Request) (*AccessToken, error)
}

type AccessToken struct {
accessToken string
Expand All @@ -16,7 +28,7 @@ type AccessToken struct {
// NewAccessToken creates a new `AccessToken` instance.
func NewAccessToken(accessToken string, expiresIn int) *AccessToken {
if expiresIn <= 0 {
expiresIn = DefaultTokenExpiresIn
expiresIn = DefaultAccessTokenExpiresIn
}
return &AccessToken{
accessToken: accessToken,
Expand Down Expand Up @@ -46,6 +58,116 @@ func (t *AccessToken) GetExpiresAt() time.Time {
return t.createdAt.Add(timeDiff)
}

// accessTokenManager is an implement of the `auth.CredentialManager`
// which is used to manage access token credentials.
type accessTokenManager struct {
atc AccessTokenClient
auth Auth
cache caches.Cache
}

// NewAccessTokenManager creates a new instance of `auth.CredentialManager`
// to manage access token credentials.
func NewAccessTokenManager(atc AccessTokenClient, auth Auth, cache caches.Cache) CredentialManager {
return &accessTokenManager{
atc: atc,
auth: auth,
cache: cache,
}
}

func (cm *accessTokenManager) Get() (interface{}, error) {
cachedValue, err := cm.get()
if err == nil {
return cachedValue, nil
}

return cm.Renew()
}

func (cm *accessTokenManager) Set(credential interface{}) error {
return errors.New("not settable")
}

func (cm *accessTokenManager) Renew() (interface{}, error) {
cm.Delete()

// TODO: prevent concurrent fetching
token, err := cm.getAccessToken()
if err != nil {
return nil, err
}

if cm.cache == nil {
err = fmt.Errorf("cache is not set")
} else {
serializedToken, err := SerializeAccessToken(token)
if err != nil {
return nil, err
}
err = cm.cache.Set(
cm.auth.GetAppId(),
caches.BizAccessToken,
serializedToken,
token.GetExpiresIn(),
)
}

return token, err
}

func (cm *accessTokenManager) Delete() error {
token, err := cm.get()
if err != nil {
return err
}
serializedToken, err := SerializeAccessToken(token)
if err != nil {
return err
}
return cm.cache.Delete(
cm.auth.GetAppId(),
caches.BizAccessToken,
serializedToken,
)
}

func (cm *accessTokenManager) get() (*AccessToken, error) {
if cm.cache == nil {
return nil, ErrCacheNotSet
}

cachedValue, err := cm.cache.Get(cm.auth.GetAppId(), caches.BizAccessToken)
if err != nil {
return nil, err
}

token, err := DeserializeAccessToken(cachedValue)
if err != nil {
return nil, err
}

return token, nil
}

func (cm *accessTokenManager) getAccessToken() (*AccessToken, error) {
req, err := cm.atc.PrepareRequest(cm.auth)
if err != nil {
return nil, err
}

resp, err := cm.atc.SendRequest(cm.auth, req)
if resp != nil {
defer resp.Body.Close()
}
if err != nil {
return nil, err
}

return cm.atc.HandleResponse(cm.auth, resp, req)
}

// TODO: make private
type accessToken struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Expand Down
68 changes: 66 additions & 2 deletions internal/auth/accesstoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"testing"
"time"

"github.com/Xavier-Lam/go-wechat/caches"
"github.com/Xavier-Lam/go-wechat/internal/auth"
"github.com/Xavier-Lam/go-wechat/internal/test"
"github.com/stretchr/testify/assert"
)

func TestTokenGetExpires(t *testing.T) {
func TestAccessTokenGetExpires(t *testing.T) {
token := auth.NewAccessToken("access_token", 2)

assert.Equal(t, 2, token.GetExpiresIn())
Expand All @@ -27,7 +29,69 @@ func TestTokenGetExpires(t *testing.T) {
assert.WithinDuration(t, time.Now().Add(time.Second*-1), token.GetExpiresAt(), time.Millisecond*50)
}

func TestTokenSerialize(t *testing.T) {
func TestAccessTokenManager(t *testing.T) {
mockAuth := auth.New("", "")
oldToken := "old"
newToken := "token"

cache := caches.NewDummyCache()
atc := test.NewMockAccessTokenClient(oldToken)
cm := auth.NewAccessTokenManager(atc, mockAuth, cache)

token, err := cm.Get()
assert.NoError(t, err)
assert.IsType(t, &auth.AccessToken{}, token)
assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken())

atc = test.NewMockAccessTokenClient(newToken)
cm = auth.NewAccessTokenManager(atc, mockAuth, cache)

token, err = cm.Get()
assert.NoError(t, err)
assert.IsType(t, &auth.AccessToken{}, token)
assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken())

token, err = cm.Renew()
assert.NoError(t, err)
assert.IsType(t, &auth.AccessToken{}, token)
assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken())

token, err = cm.Get()
assert.NoError(t, err)
assert.IsType(t, &auth.AccessToken{}, token)
assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken())
}

func TestAccessTokenManagerDelete(t *testing.T) {
oldToken := "old"
newToken := "token"
mockAuth := auth.New("", "")

cache := caches.NewDummyCache()
atc := test.NewMockAccessTokenClient(oldToken)
cm := auth.NewAccessTokenManager(atc, mockAuth, cache)

err := cm.Delete()
assert.Error(t, err)

token, err := cm.Get()
assert.NoError(t, err)
assert.IsType(t, &auth.AccessToken{}, token)
assert.Equal(t, oldToken, token.(*auth.AccessToken).GetAccessToken())

atc = test.NewMockAccessTokenClient(newToken)
cm = auth.NewAccessTokenManager(atc, mockAuth, cache)

err = cm.Delete()
assert.NoError(t, err)

token, err = cm.Get()
assert.NoError(t, err)
assert.IsType(t, &auth.AccessToken{}, token)
assert.Equal(t, newToken, token.(*auth.AccessToken).GetAccessToken())
}

func TestAccessTokenSerialize(t *testing.T) {
token := auth.NewAccessToken("access_token", 2)
// Serialize the token
bytes, err := auth.SerializeAccessToken(token)
Expand Down
30 changes: 0 additions & 30 deletions internal/auth/credential.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package auth

import "errors"

// It would be much better if Go supports covariance...
type CredentialManager interface {
// Get the latest credential
Expand All @@ -16,31 +14,3 @@ type CredentialManager interface {
// Delete a credential
Delete() error
}

type AuthCredentialManager struct {
auth Auth
}

// Provide `Auth`
func NewAuthCredentialManager(auth Auth) CredentialManager {
return &AuthCredentialManager{auth: auth}
}

func (cm *AuthCredentialManager) Get() (interface{}, error) {
if cm.auth == nil {
return errors.New("auth not set"), nil
}
return cm.auth, nil
}

func (cm *AuthCredentialManager) Set(credential interface{}) error {
return errors.New("not settable")
}

func (cm *AuthCredentialManager) Renew() (interface{}, error) {
return nil, errors.New("not renewable")
}

func (cm *AuthCredentialManager) Delete() error {
return errors.New("not deletable")
}
7 changes: 7 additions & 0 deletions internal/auth/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package auth

import "errors"

var (
ErrCacheNotSet = errors.New("cache not set")
)
32 changes: 17 additions & 15 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ type WeChatClient interface {

// Config is a configuration struct used to set up a `client.WeChatClient`.
type Config struct {
// AccessTokenManagerFactory is a factory function that creates a `CredentialManager` managing the access token.
// AccessTokenClient is used for request a latest access token when it is needed
// This option should be left as the default value (nil), unless you want to customize the client
// For example, if you want to request your access token from a different service than Tencent's.
AccessTokenManagerFactory AccessTokenManagerProvider
AccessTokenClient auth.AccessTokenClient

// BaseApiUrl is the base URL used for making API requests.
// If not provided, the default value is 'https://api.weixin.qq.com'.
Expand Down Expand Up @@ -87,11 +87,11 @@ type DefaultWeChatClient struct {
}

// Create a new `WeChatClient`
func New(auth auth.Auth, conf Config) WeChatClient {
func New(a auth.Auth, conf Config) WeChatClient {
var (
atc auth.AccessTokenClient
baseApiUrl *url.URL
client http.Client
factory AccessTokenManagerProvider
baseClient http.Client
)

if conf.BaseApiUrl == nil {
Expand All @@ -101,28 +101,30 @@ func New(auth auth.Auth, conf Config) WeChatClient {
}

if conf.HttpClient == nil {
client = http.Client{Transport: http.DefaultTransport}
baseClient = http.Client{Transport: http.DefaultTransport}
} else {
client = *conf.HttpClient
baseClient = *conf.HttpClient
}

if conf.AccessTokenManagerFactory == nil {
factory = AccessTokenManagerFactory
if conf.AccessTokenClient == nil {
client := baseClient
atc = NewAccessTokenClient(&client, "")
} else {
factory = conf.AccessTokenManagerFactory
atc = conf.AccessTokenClient
}
cm := factory(auth, client, conf.Cache, nil)

client.Transport =
cm := auth.NewAccessTokenManager(atc, a, conf.Cache)

baseClient.Transport =
NewCredentialRoundTripper(cm,
NewAccessTokenRoundTripper(
NewCommonRoundTripper(
baseApiUrl, client.Transport)))
baseApiUrl, baseClient.Transport)))

return &DefaultWeChatClient{
cm: cm,
auth: auth,
client: &client,
auth: a,
client: &baseClient,
}
}

Expand Down
Loading

0 comments on commit 0ad3620

Please sign in to comment.