diff --git a/kyc/quiz/quiz.go b/kyc/quiz/quiz.go index e8c35a52..5b429411 100644 --- a/kyc/quiz/quiz.go +++ b/kyc/quiz/quiz.go @@ -202,17 +202,63 @@ func wrapErrorInTx(err error) error { return err } -func (r *repositoryImpl) StartQuizSession(ctx context.Context, userID UserID, lang string) (*Quiz, error) { //nolint:funlen //. - err := r.CheckUserKYC(ctx, userID) +func (r *repositoryImpl) finishExpiredSession( //nolint:funlen //. + ctx context.Context, + userID UserID, + now *time.Time, + tx storage.QueryExecer, +) (*time.Time, error) { + // $1: user_id. + // $2: max session duration (seconds). + // $3: session cool down (seconds). + const stmt = ` + with result as ( + update quiz_sessions + set + ended_at = now(), + ended_successfully = false + where + user_id = $1 and + ended_at is null and + started_at + make_interval(secs => $2) < now() + returning * + ) + insert into failed_quiz_sessions (started_at, ended_at, questions, answers, language, user_id, skipped) + select + result.started_at, + result.ended_at, + result.questions, + result.answers, + result.language, + result.user_id, + false + from + result + returning + ended_at + make_interval(secs => $3) as cooldown_at + ` + data, err := storage.ExecOne[struct { + CooldownAt *time.Time `db:"cooldown_at"` + }](ctx, tx, stmt, userID, r.config.MaxSessionDurationSeconds, r.config.SessionCoolDownSeconds) if err != nil { - return nil, err - } + if errors.Is(err, storage.ErrNotFound) { + err = nil + } - questions, err := r.SelectQuestions(ctx, r.DB, lang) - if err != nil { return nil, err } + return data.CooldownAt, errors.Wrapf(r.modifyUser(ctx, false, now, userID), "failed to modifyUser") +} + +func (r *repositoryImpl) startNewSession( //nolint:funlen,revive //. + ctx context.Context, + userID UserID, + now *time.Time, + tx storage.QueryExecer, + lang string, + questions []*Question, +) (*Quiz, error) { // $1: user_id. // $2: language. // $3: questions. @@ -297,7 +343,7 @@ func (r *repositoryImpl) StartQuizSession(ctx context.Context, userID UserID, la ActiveEndedAt *time.Time `db:"active_ended_at"` UpsertStartedAt *time.Time `db:"upsert_started_at"` UpsertDeadline *time.Time `db:"upsert_deadline"` - }](ctx, r.DB, stmt, userID, lang, questionsToSlice(questions), r.config.SessionCoolDownSeconds, r.config.MaxSessionDurationSeconds) + }](ctx, tx, stmt, userID, lang, questionsToSlice(questions), r.config.SessionCoolDownSeconds, r.config.MaxSessionDurationSeconds) if err != nil { if errors.Is(err, storage.ErrRelationNotFound) { err = ErrUnknownUser @@ -306,7 +352,6 @@ func (r *repositoryImpl) StartQuizSession(ctx context.Context, userID UserID, la return nil, errors.Wrap(err, "failed to start session") } - now := stdlibtime.Now().Truncate(stdlibtime.Second).UTC() switch { case data.FailedAt != nil: // Failed session is still in cool down. return nil, errors.Wrapf(ErrSessionFinishedWithError, "wait until %v", @@ -321,8 +366,8 @@ func (r *repositoryImpl) StartQuizSession(ctx context.Context, userID UserID, la return nil, ErrSessionFinishedWithError } - if data.ActiveDeadline.After(now) { - return nil, errors.Wrapf(ErrSessionIsAlreadyRunning, "wait %s before next session", data.ActiveDeadline.Sub(now)) + if data.ActiveDeadline.After(*now.Time) { + return nil, errors.Wrapf(ErrSessionIsAlreadyRunning, "wait %s before next session", data.ActiveDeadline.Sub(*now.Time)) } case data.UpsertStartedAt != nil: // New session is started. @@ -338,6 +383,34 @@ func (r *repositoryImpl) StartQuizSession(ctx context.Context, userID UserID, la panic("unreachable: " + userID) } +func (r *repositoryImpl) StartQuizSession(ctx context.Context, userID UserID, lang string) (quiz *Quiz, err error) { + err = r.CheckUserKYC(ctx, userID) + if err != nil { + return nil, err + } + + questions, err := r.SelectQuestions(ctx, r.DB, lang) + if err != nil { + return nil, err + } + + err = storage.DoInTransaction(ctx, r.DB, func(tx storage.QueryExecer) error { + now := time.Now() + cooldown, fErr := r.finishExpiredSession(ctx, userID, now, tx) + if fErr != nil { + return wrapErrorInTx(fErr) + } else if cooldown != nil { + return wrapErrorInTx(errors.Wrapf(ErrSessionFinishedWithError, "wait until %v", cooldown)) + } + + quiz, err = r.startNewSession(ctx, userID, now, tx, lang, questions) + + return wrapErrorInTx(err) + }) + + return quiz, err +} + func calculateProgress(correctAnswers, currentAnswers []uint8) (correctNum, incorrectNum uint8) { correct := correctAnswers if len(currentAnswers) < len(correctAnswers) { diff --git a/kyc/quiz/quiz_test.go b/kyc/quiz/quiz_test.go index f6ff25da..39b49e8f 100644 --- a/kyc/quiz/quiz_test.go +++ b/kyc/quiz/quiz_test.go @@ -164,13 +164,8 @@ func testManagerSessionStart(ctx context.Context, t *testing.T, r *repositoryImp t.Run("Expired", func(t *testing.T) { helperForceResetSessionStartedAt(t, r, "bogus") session, err := r.StartQuizSession(ctx, "bogus", "en") - require.NoError(t, err) - require.NotNil(t, session) - require.NotNil(t, session.Progress) - require.NotNil(t, session.Progress.ExpiresAt) - require.NotEmpty(t, session.Progress.NextQuestion) - require.Equal(t, uint8(3), session.Progress.MaxQuestions) - require.Equal(t, uint8(1), session.Progress.NextQuestion.Number) + require.ErrorIs(t, err, ErrSessionFinishedWithError) + require.Nil(t, session) }) t.Run("CoolDown", func(t *testing.T) { helperSessionReset(t, r, "bogus", true)