Skip to content

Commit

Permalink
修改路由匹配方式,修复社交登陆无法处理问题
Browse files Browse the repository at this point in the history
  • Loading branch information
陈泉 committed Nov 22, 2022
1 parent 80d20c6 commit 23009af
Show file tree
Hide file tree
Showing 15 changed files with 570 additions and 574 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"putil-promisify": "^1.8.5",
"qcloud-cos-sts": "^3.1.0",
"qs": "^6.11.0",
"route-recognizer": "^0.3.4",
"tsrpc": "^3.4.10",
"zod": "^3.19.1"
}
Expand Down
1 change: 1 addition & 0 deletions src/api/auth/ApiSocial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export default async function (call: ApiCall<ReqSocial, ResSocial>) {
let qqUrl = await redis.getCacheKey(`auth:${call.req.id}`)
if (!qqUrl) {
qqUrl = await qqProvider.getAuthorizationUrl({
id: call.req.id as never,
srp: call.req.srp,
erp: call.req.erp,
})
Expand Down
38 changes: 26 additions & 12 deletions src/api/custom/auth-callback.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import { ApiHandle } from '../../kernel/withHttpServer/types'
import { z } from 'zod'
import authManager from '../../kernel/auth'
import redis from '../../kernel/redis'

const schema = z.object({
erp: z.string().nullable().optional(),
srp: z.string().nullable().optional(),
sn: z.string(),
state: z.string(),
})
const schema = z
.object({
erp: z.string().nullable().optional(),
srp: z.string().nullable().optional(),
sn: z.string(),
state: z.string(),
id: z.string().optional().nullable(),
})
.passthrough()

function addParams(url: string, params: Record<string, string>) {
const u = new URL(url)
Expand All @@ -24,29 +28,39 @@ export const AuthCallback: ApiHandle = async (req, res) => {
if (!provider) {
if (inputData.erp) {
return res.redirect(
addParams(inputData.erp, { error: 'provider not found' }),
addParams(inputData.erp, {
error: '登陆服务错误,请联系管理员',
}),
)
}
return res.send('No provider found', 404)
return res.send('登陆服务错误,请联系管理员', 404)
}

const result = await provider.verifyCallback(inputData)

console.info('verifyCallback result', result)

if (result) {
const accessToken = await provider.getUserAccessToken(inputData)
const user = await provider.getUserInfo(accessToken)

console.log(user)
// TODO 处理用户信息

if (inputData.srp) {
return res.redirect(addParams(inputData.srp, { success: 'true' }))
}
return res.json({})
return res.json(user)
}

if (inputData.id) {
await redis.delCacheKey(`auth:${inputData.id}`)
}

if (inputData.erp) {
return res.redirect(
addParams(inputData.erp, { error: 'verify input data error' }),
addParams(inputData.erp, { error: '参数错误无法登陆' }),
)
}
return res.send('fail', 404)

return res.send('登陆失败,即将回到主站', 404)
}
17 changes: 4 additions & 13 deletions src/api/custom/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
import { ApiHandleMap } from '../../kernel/withHttpServer/types'
import { WechatCallback } from './wechat-callback'
import { AuthCallback } from './auth-callback'
import RouteRecognizer from 'route-recognizer'

export const TrdApis: ApiHandleMap = {
'/trd/wechat': WechatCallback,
'/custom/auth/social': AuthCallback,
}
export const customRouter = new RouteRecognizer()

export const TrdApiKeys = Object.keys(TrdApis)
.map((item) => {
return [item.replaceAll('//', '/').toLowerCase(), item] as [
string,
keyof typeof TrdApis,
]
})
.filter((item) => item[0].length)
customRouter.add([{ path: '/custom/auth/social', handler: AuthCallback }])
customRouter.add([{ path: '/trd/wechat', handler: WechatCallback }])
23 changes: 11 additions & 12 deletions src/kernel/auth/abstract-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface OAuthProviderBaseOptions {
export interface BaseAuthorizationUrl {
erp?: string
srp?: string
id: string
}

function resolve(from: string, to: string) {
Expand All @@ -36,30 +37,27 @@ export abstract class OAuthProvider {
return id
}

static async getStateValue(state: string): Promise<string | null> {
const value = await redis.get(`oauth:state:${state}`)
if (value) {
await redis.del(`oauth:state:${state}`)
}
return value
static getStateValue(state: string): Promise<string | null> {
return redis.get(`oauth:state:${state}`)
}

static async verifyState(state: string) {
const result = await redis.get(`oauth:state:${state}`)
if (result) {
await redis.del(`oauth:state:${state}`)
return true
}
return false
return !!result
}

static async clearState(state: string): Promise<number> {
return redis.del(`oauth:state:${state}`)
}

getCallbackUrl(
id: string,
errorRedirectPath?: string,
successRedirectPath?: string,
): Promise<string> {
const appUrl = env.APP_URL
const callbackPath = this.options.callbackPath || '/custom/auth/social'
const callbackUrl = resolve(callbackPath, appUrl)
const callbackUrl = resolve(appUrl, callbackPath)

const url = new URL(callbackUrl)
// 如果失败了,就跳转到 errorRedirectPath
Expand All @@ -72,6 +70,7 @@ export abstract class OAuthProvider {
}

url.searchParams.set('sn', this.options.name)
url.searchParams.set('id', id)

return Promise.resolve(url.toString())
}
Expand Down
2 changes: 1 addition & 1 deletion src/kernel/auth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export class AuthManager {
const driveClass = AuthManager.getDrive(driver)
const instance = new driveClass({
name,
options,
...options,
})
this.instances.set(name, instance)
return instance
Expand Down
7 changes: 6 additions & 1 deletion src/kernel/auth/providers/qq.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ export class QQProvider extends OAuthProvider {

async getAuthorizationUrl(params: BaseAuthorizationUrl): Promise<string> {
const scope = ['get_user_info']
const redirectUrl = await this.getCallbackUrl(params.erp, params.srp)
const redirectUrl = await this.getCallbackUrl(
params.id,
params.erp,
params.srp,
)
const url = new URL('https://graph.qq.com/oauth2.0/authorize')
url.searchParams.set('response_type', 'code')
url.searchParams.set('client_id', this.options.clientId)
Expand Down Expand Up @@ -118,6 +122,7 @@ export class QQProvider extends OAuthProvider {
if (isQQApiError(data)) {
return reject(new Error(data.error_description))
}
OAuthProvider.clearState(params.state)
return resolve(data)
})
.catch((error) => {
Expand Down
8 changes: 6 additions & 2 deletions src/kernel/redis.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ import Redis, { RedisOptions } from 'ioredis'

class RedisHelper extends Redis {
async setCacheKey(key: string, value: string, ttl = 60 * 10) {
await this.set(key, value, 'EX', ttl)
await this.set(`cache:${key}`, value, 'EX', ttl)
}

async getCacheKey(key: string): Promise<string | null> {
return this.get(key)
return this.get(`cache:${key}`)
}

async delCacheKey(key: string): Promise<number> {
return this.del(`cache:${key}`)
}
}

Expand Down
Loading

0 comments on commit 23009af

Please sign in to comment.