diff --git a/kyc/quiz/quiz.go b/kyc/quiz/quiz.go index d7841aa2..b5e3cb3a 100644 --- a/kyc/quiz/quiz.go +++ b/kyc/quiz/quiz.go @@ -97,89 +97,59 @@ func (r *repositoryImpl) validateKycStep(user *users.User) error { return nil } -func (r *repositoryImpl) SkipQuizSession(ctx context.Context, userID UserID) error { - if err := r.CheckUserKYC(ctx, userID); err != nil { - return err - } - - now := time.Now() - for _, fn := range []func(context.Context, UserID, stdlibtime.Time, storage.QueryExecer) error{ - r.CheckUserFailedSession, - r.CheckUserActiveSession, - } { - if err := fn(ctx, userID, *now.Time, r.DB); err != nil { - return err - } - } - - return errors.Wrapf(r.UserMarkSessionAsFinished(ctx, userID, *now.Time, r.DB, false, true), - "failed to UserMarkSessionAsFinished for userID:%v", userID) -} - -func (r *repositoryImpl) CheckUserFailedSession(ctx context.Context, userID UserID, now stdlibtime.Time, tx storage.QueryExecer) error { - type failedSession struct { - EndedAt stdlibtime.Time `db:"ended_at"` - } - +func (r *repositoryImpl) SkipQuizSession(ctx context.Context, userID UserID) error { //nolint:funlen //. + // $1: user_id. const stmt = ` -select max(ended_at) as ended_at from failed_quiz_sessions where user_id = $1 having max(ended_at) > $2 + select + started_at, + ended_at is not null as finished, + ended_successfully + from + quiz_sessions + where + user_id = $1 + for update ` - term := now. - Add(stdlibtime.Duration(-r.config.SessionCoolDownSeconds) * stdlibtime.Second). - Truncate(stdlibtime.Second) - data, err := storage.Get[failedSession](ctx, tx, stmt, userID, term) - if err != nil { - if errors.Is(err, storage.ErrNotFound) { - return nil - } - - return errors.Wrap(err, "failed to get failed session data") + if err := r.CheckUserKYC(ctx, userID); err != nil { + return err } - next := data.EndedAt. - Add(stdlibtime.Duration(r.config.SessionCoolDownSeconds) * stdlibtime.Second). - Truncate(stdlibtime.Second). - UTC() + err := storage.DoInTransaction(ctx, r.DB, func(tx storage.QueryExecer) error { + now := time.Now() - return errors.Wrapf(ErrSessionFinishedWithError, "wait until %v", next) -} - -func (r *repositoryImpl) CheckUserActiveSession(ctx context.Context, userID UserID, now stdlibtime.Time, tx storage.QueryExecer) error { - type userSession struct { - StartedAt time.Time `db:"started_at"` - EndedAt *time.Time `db:"ended_at"` - Finished bool `db:"finished"` - FinishedSuccessfully bool `db:"ended_successfully"` - } - const stmt = `select started_at, ended_at, ended_at is not null as finished, ended_successfully from quiz_sessions where user_id = $1` + data, err := storage.Get[struct { + StartedAt *time.Time `db:"started_at"` + Finished bool `db:"finished"` + Success bool `db:"ended_successfully"` + }](ctx, tx, stmt, userID) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + return wrapErrorInTx(ErrUnknownSession) + } - data, err := storage.Get[userSession](ctx, tx, stmt, userID) - if err != nil { - if errors.Is(err, storage.ErrNotFound) { - return nil + return errors.Wrap(wrapErrorInTx(err), "failed to get session data") } - return errors.Wrap(err, "failed to get active session data") - } + switch { + case data.StartedAt == nil: + return wrapErrorInTx(ErrUnknownSession) - if data.Finished { - if data.FinishedSuccessfully { - return ErrSessionFinished - } + case data.StartedAt.Add(stdlibtime.Duration(r.config.MaxSessionDurationSeconds) * stdlibtime.Second).Before(*now.Time): + return wrapErrorInTx(ErrSessionExpired) + + case data.Finished: + if data.Success { + return wrapErrorInTx(ErrSessionFinished) + } - cooldown := data.EndedAt.Add(stdlibtime.Duration(r.config.SessionCoolDownSeconds) * stdlibtime.Second) - if cooldown.After(now) { - return ErrSessionFinishedWithError + return wrapErrorInTx(ErrSessionFinishedWithError) } - } - deadline := data.StartedAt.Add(stdlibtime.Duration(r.config.MaxSessionDurationSeconds) * stdlibtime.Second) - if deadline.After(now) { - return ErrSessionIsAlreadyRunning - } + return wrapErrorInTx(r.UserMarkSessionAsFinished(ctx, userID, *now.Time, tx, false, true)) + }) - return nil + return errors.Wrap(err, "failed to skip session") } func (r *repositoryImpl) SelectQuestions(ctx context.Context, tx storage.QueryExecer, lang string) ([]*Question, error) { @@ -215,35 +185,6 @@ func questionsToSlice(questions []*Question) []uint { return result } -func (*repositoryImpl) CreateSessionEntry( //nolint:revive //. - ctx context.Context, - userID UserID, - lang string, - questions []*Question, - now stdlibtime.Time, - tx storage.QueryExecer, -) error { - const stmt = ` -insert into quiz_sessions (user_id, language, questions, started_at, answers) values ($1, $2, $3, $4, '{}'::smallint[]) - on conflict on constraint quiz_sessions_pkey do update - set - started_at = excluded.started_at, - questions = excluded.questions, - answers = excluded.answers, - language = excluded.language, - ended_successfully = false - ` - - _, err := storage.Exec(ctx, tx, stmt, userID, lang, questionsToSlice(questions), now) - if err != nil { - if errors.Is(err, storage.ErrRelationNotFound) { - err = ErrUnknownUser - } - } - - return errors.Wrap(err, "failed to create session entry") -} - func wrapErrorInTx(err error) error { if err == nil { return nil @@ -551,7 +492,7 @@ select result.answers, result.language, result.user_id, - $4 AS skipped + $4 AS skipped from result where result.ended_successfully = false diff --git a/kyc/quiz/quiz_test.go b/kyc/quiz/quiz_test.go index 48ae1886..e3aac2be 100644 --- a/kyc/quiz/quiz_test.go +++ b/kyc/quiz/quiz_test.go @@ -185,7 +185,63 @@ func testManagerSessionStart(ctx context.Context, t *testing.T, r *repositoryImp require.ErrorIs(t, err, ErrSessionFinishedWithError) }) }) +} + +func testManagerSessionSkip(ctx context.Context, t *testing.T, r *repositoryImpl) { + t.Run("OK", func(t *testing.T) { + helperSessionReset(t, r, "bogus", true) + + _, err := r.StartQuizSession(ctx, "bogus", "en") + require.NoError(t, err) + + err = r.SkipQuizSession(ctx, "bogus") + require.NoError(t, err) + + _, err = r.StartQuizSession(ctx, "bogus", "en") + require.ErrorIs(t, err, ErrSessionFinishedWithError) + }) + t.Run("UnknownSession", func(t *testing.T) { + helperSessionReset(t, r, "bogus", true) + + err := r.SkipQuizSession(ctx, "bogus") + require.ErrorIs(t, err, ErrUnknownSession) + }) + t.Run("Expired", func(t *testing.T) { + helperSessionReset(t, r, "bogus", true) + + _, err := r.StartQuizSession(ctx, "bogus", "en") + require.NoError(t, err) + + helperForceResetSessionStartedAt(t, r, "bogus") + + err = r.SkipQuizSession(ctx, "bogus") + require.ErrorIs(t, err, ErrSessionExpired) + }) + t.Run("Finished", func(t *testing.T) { + t.Run("Success", func(t *testing.T) { + helperSessionReset(t, r, "bogus", true) + + _, err := r.StartQuizSession(ctx, "bogus", "en") + require.NoError(t, err) + + helperForceFinishSession(t, r, "bogus", true) + + err = r.SkipQuizSession(ctx, "bogus") + require.ErrorIs(t, err, ErrSessionFinished) + }) + t.Run("Error", func(t *testing.T) { + helperSessionReset(t, r, "bogus", true) + + _, err := r.StartQuizSession(ctx, "bogus", "en") + require.NoError(t, err) + + err = r.SkipQuizSession(ctx, "bogus") + require.NoError(t, err) + err = r.SkipQuizSession(ctx, "bogus") + require.ErrorIs(t, err, ErrSessionFinishedWithError) + }) + }) } func testManagerSessionContinueErrors(ctx context.Context, t *testing.T, r *repositoryImpl) { @@ -371,5 +427,9 @@ func TestSessionManager(t *testing.T) { testManagerSessionContinueWithIncorrectAnswers(ctx, t, repo) }) + t.Run("Skip", func(t *testing.T) { + testManagerSessionSkip(ctx, t, repo) + }) + require.NoError(t, repo.Close()) }