diff --git a/cmd/eskimo-hut/contract.go b/cmd/eskimo-hut/contract.go index d0a84013..261a389f 100644 --- a/cmd/eskimo-hut/contract.go +++ b/cmd/eskimo-hut/contract.go @@ -136,7 +136,7 @@ type ( StartOrContinueKYCStep4SessionRequestBody struct { SelectedOption *uint8 `form:"selectedOption" required:"true" swaggerignore:"true" example:"0"` Language string `form:"language" required:"true" swaggerignore:"true" example:"en"` - QuestionNumber uint8 `form:"questionNumber" required:"true" swaggerignore:"true" example:"11"` + QuestionNumber uint `form:"questionNumber" required:"true" swaggerignore:"true" example:"11"` } VerifySocialKYCStepRequestBody struct { Social kycsocial.Type `form:"social" required:"true" swaggerignore:"true" example:"twitter"` diff --git a/cmd/eskimo-hut/kyc.go b/cmd/eskimo-hut/kyc.go index 756f36bc..4720c9a0 100644 --- a/cmd/eskimo-hut/kyc.go +++ b/cmd/eskimo-hut/kyc.go @@ -86,7 +86,7 @@ func (s *service) StartOrContinueKYCStep4Session( //nolint:gocritic,funlen,reviv fmt.Sprintf("[%v]Ice is not mined, but it turns out immediately after registration", strings.Repeat("bogus", rand.Intn(20))), //nolint:gosec,gomnd,lll // . fmt.Sprintf("[%v]Ice is cool", strings.Repeat("bogus", rand.Intn(20))), //nolint:gosec,gomnd // . }, - Number: uint8(11 + rand.Intn(20)), //nolint:gosec,gomnd // . + Number: uint(11 + rand.Intn(20)), //nolint:gosec,gomnd // . Text: fmt.Sprintf("[%v][%v] What are the major differences between Ice, Pi and Bee?", req.Data.Language, strings.Repeat("bogus", rand.Intn(20))), //nolint:gosec,gomnd,lll // . }, MaxQuestions: uint8(30 + rand.Intn(40)), //nolint:gosec,gomnd // . diff --git a/kyc/quiz/contract.go b/kyc/quiz/contract.go index 44e22303..b2afab7c 100644 --- a/kyc/quiz/contract.go +++ b/kyc/quiz/contract.go @@ -22,13 +22,18 @@ const ( ) type ( - UserID = users.UserID + UserID = users.UserID + UserProfile = users.UserProfile Manager interface { io.Closer Start(ctx context.Context, userID UserID, lang string, total int) (Quiz, error) - Continue(ctx context.Context, userID UserID, question, answer uint8) (Quiz, error) + Continue(ctx context.Context, userID UserID, question uint, answer uint8) (Quiz, error) + } + + UserReader interface { + GetUserByID(ctx context.Context, userID UserID) (*UserProfile, error) } Result string @@ -46,7 +51,7 @@ type ( Question struct { Text string `json:"text" example:"Какая температура на улице?" db:"question"` Options []string `json:"options" example:"+21,-2,+33,0" db:"options"` - Number uint8 `json:"number" example:"1" db:"id"` + Number uint `json:"number" example:"1" db:"id"` } ) @@ -59,6 +64,7 @@ var ( ErrSessionFinished = errors.New("session closed") ErrSessionExpired = errors.New("session expired") ErrUnknownQuestionNumber = errors.New("unknown question number") + ErrUnknownSession = errors.New("unknown session and/or user") ) const ( @@ -68,9 +74,11 @@ const ( defaultQuestionsCacheExpireSeconds = 60 * 60 * 24 // 1 day. ) -var ( //nolint:gofumpt //. +var ( //go:embed DDL.sql ddl string + + _ UserReader = users.ReadRepository(nil) ) type ( @@ -82,7 +90,7 @@ type ( managerImpl struct { DB *storage.DB Shutdown func() error - Users users.Repository + Users UserReader Cache struct { Questions map[string]quizDetails Timestamp *time.Time diff --git a/kyc/quiz/quiz.go b/kyc/quiz/quiz.go index 610f6008..2c8f200d 100644 --- a/kyc/quiz/quiz.go +++ b/kyc/quiz/quiz.go @@ -8,7 +8,6 @@ import ( "strconv" stdlibtime "time" - "github.com/hashicorp/go-multierror" "github.com/pkg/errors" "github.com/valyala/fastrand" @@ -34,13 +33,16 @@ func loadConfig() config { return cfg } -func NewManager(ctx context.Context) Manager { - repo := users.New(ctx, nil) +func NewManager(ctx context.Context, userReader UserReader) Manager { + return newManagerImpl(ctx, userReader) +} + +func newManagerImpl(ctx context.Context, userReader UserReader) *managerImpl { db := storage.MustConnect(ctx, ddl, applicationYamlKey) manager := &managerImpl{ DB: db, Shutdown: db.Close, - Users: repo, + Users: userReader, config: loadConfig(), } manager.Cache.Timestamp = time.New(stdlibtime.Time{}) @@ -136,15 +138,11 @@ func (m *managerImpl) Close() (err error) { err = m.Shutdown() } - if m.Users != nil { - err = multierror.Append(err, m.Users.Close()).ErrorOrNil() - } - return } -func questionsToSlice(questions []Question) []uint8 { - s := make([]uint8, 0, len(questions)) +func questionsToSlice(questions []Question) []uint { + s := make([]uint, 0, len(questions)) for i := range questions { s = append(s, questions[i].Number) } @@ -258,7 +256,7 @@ func (m *managerImpl) Start(ctx context.Context, userID UserID, lang string, tot return Quiz{ Progress: &Progress{ ExpiresAt: time.New(expires), - MaxQuestions: uint8(len(questions)), + MaxQuestions: uint8(len(questions)) - 1, NextQuestion: &Question{ Text: questions[0].Text, Options: questions[0].Options, @@ -277,13 +275,15 @@ func calculateProgress(correctAnswers, currentAnswers []int) (correctNum, incorr for i := range correct { if correct[i] == currentAnswers[i] { correctNum++ + } else { + incorrectNum++ } } - return correctNum, uint8(len(currentAnswers) - len(correct)) + return } -func (m *managerImpl) Continue(ctx context.Context, userID UserID, question, answer uint8) (Quiz, error) { //nolint:funlen // It's expected. +func (m *managerImpl) Continue(ctx context.Context, userID UserID, question uint, answer uint8) (Quiz, error) { //nolint:funlen // It's expected. // $1: user_id // $2: question // $3: answer @@ -309,15 +309,15 @@ func (m *managerImpl) Continue(ctx context.Context, userID UserID, question, ans ), corrent_answers as ( select - array_agg(questions.correct_option) as ids + array_agg(questions.correct_option order by q.nr) as ids from - questions, - session_data + session_data, + questions + inner join unnest(session_data.questions) with ordinality AS q(id, nr) + on questions.id = q.id where - session_data.valid is true and questions.language = session_data.language and - questions.quiz_id = session_data.quiz_id and - questions.id = ANY(ARRAY[session_data.questions]) + questions.quiz_id = session_data.quiz_id ), session_update as ( update quizz_sessions @@ -356,12 +356,11 @@ func (m *managerImpl) Continue(ctx context.Context, userID UserID, question, ans questions.id = session_data.questions[session_data.question_idx + 1] ) select - session_data.user_id, - COALESCE(session_data.expired, false) as expired, - COALESCE(session_data.expires_at, now()) as expires_at, - COALESCE(session_data.finished, false) as finished, - COALESCE(session_data.valid, false) as valid, - COALESCE(array_length(session_data.questions, 1) - session_data.question_idx + 1, 0) as left, + session_data.expired, + session_data.expires_at, + session_data.finished, + session_data.valid, + COALESCE((array_length(session_data.questions, 1) - session_data.question_idx) - 1, 0) as left, COALESCE(next_question.id, 0) as id, COALESCE(next_question.question, '') as question, @@ -374,12 +373,11 @@ func (m *managerImpl) Continue(ctx context.Context, userID UserID, question, ans COALESCE(session_update.answers, session_data.answers, '{}'::INT[]) as current_answers from session_data - full outer join next_question on true - full outer join session_update on true - full outer join corrent_answers on true + left join next_question on true + left join session_update on true + left join corrent_answers on true ` - type PipelineResult struct { //nolint:govet // Let's keep it as is for better readability. - User *string `db:"user_id"` + type PipelineResult struct { //nolint:govet // Let's keep it as is for better readability.` Expired bool `db:"expired"` ExpiresAt time.Time `db:"expires_at"` Finished bool `db:"finished"` @@ -396,20 +394,20 @@ func (m *managerImpl) Continue(ctx context.Context, userID UserID, question, ans result, err := storage.Get[PipelineResult](ctx, m.DB, pipeline, userID, int(question), int(answer), m.config.MaxSessionDurationSeconds) if err != nil { - if errors.As(err, &storage.ErrNotFound) { - return Quiz{}, ErrUnknownUser + if errors.Is(err, storage.ErrNotFound) { + return Quiz{}, ErrUnknownSession } return Quiz{}, errors.Wrap(err, "failed to continue session") } switch { - case result.User == nil: - return Quiz{}, ErrUnknownUser - case result.Expired && !result.Finished: return Quiz{}, ErrSessionExpired + case result.Valid && result.Finished && !result.Ended: + return Quiz{}, ErrSessionFinished + case !result.Valid: return Quiz{}, ErrUnknownQuestionNumber diff --git a/kyc/quiz/quiz_test.go b/kyc/quiz/quiz_test.go index 6ea14275..6df6b617 100644 --- a/kyc/quiz/quiz_test.go +++ b/kyc/quiz/quiz_test.go @@ -7,19 +7,293 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/ice-blockchain/eskimo/users" + "github.com/ice-blockchain/wintr/connectors/storage/v2" ) -func TestManagerUnknownUser(t *testing.T) { +func helperDeleteAllSessions(t *testing.T, m *managerImpl, userID UserID) { + t.Helper() + + _, err := storage.Exec(context.TODO(), m.DB, "DELETE FROM quizz_sessions WHERE user_id = $1", userID) + require.NoError(t, err) +} + +func helperForceFinishSession(t *testing.T, m *managerImpl, userID UserID) { + t.Helper() + + _, err := storage.Exec(context.TODO(), m.DB, "update quizz_sessions set ended_at = now() where user_id = $1", userID) + require.NoError(t, err) +} + +func helperForceResetSessionStartedAt(t *testing.T, m *managerImpl, userID UserID) { + t.Helper() + + _, err := storage.Exec(context.TODO(), m.DB, "update quizz_sessions set ended_at = NULL, started_at = to_timestamp(42) where user_id = $1", userID) + require.NoError(t, err) +} + +type mockUserReader struct{} + +func (*mockUserReader) GetUserByID(ctx context.Context, userID UserID) (*UserProfile, error) { + profile := &UserProfile{ + User: &users.User{}, + } + + switch userID { + case "bogus": + s := users.Social1KYCStep + profile.KYCStepPassed = &s + + case "invalid_kyc": + s := users.LivenessDetectionKYCStep + profile.KYCStepPassed = &s + + case "storage_error": + return nil, storage.ErrCheckFailed + + case "unknown_user": + return nil, storage.ErrNotFound + } + + return profile, nil +} + +func testManagerSessionStart(ctx context.Context, t *testing.T, m *managerImpl) { + helperDeleteAllSessions(t, m, "bogus") + + t.Run("UnknownUser", func(t *testing.T) { + _, err := m.Start(ctx, "unknown_user", "en", 1) + require.ErrorIs(t, err, ErrUnknownUser) + }) + + t.Run("UnknownLanguage", func(t *testing.T) { + _, err := m.Start(ctx, "bogus", "ff", 1) + require.ErrorIs(t, err, ErrUnknownLanguage) + }) + + t.Run("NotEnoughData", func(t *testing.T) { + _, err := m.Start(ctx, "bogus", "en", 0xff) + require.ErrorIs(t, err, ErrNotEnoughData) + }) + + t.Run("InvalidKYCState", func(t *testing.T) { + _, err := m.Start(ctx, "invalid_kyc", "en", 1) + require.ErrorIs(t, err, ErrInvalidKYCState) + }) + + t.Run("StorageError", func(t *testing.T) { + _, err := m.Start(ctx, "storage_error", "en", 1) + require.ErrorIs(t, err, storage.ErrCheckFailed) + }) + + t.Run("Sessions", func(t *testing.T) { + t.Run("OK", func(t *testing.T) { + session, err := m.Start(ctx, "bogus", "en", 2) + require.NoError(t, err) + require.NotNil(t, session) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.NotEmpty(t, session.Progress.NextQuestion) + require.Equal(t, uint8(1), session.Progress.MaxQuestions) + }) + t.Run("AlreadyExists", func(t *testing.T) { + _, err := m.Start(ctx, "bogus", "en", 1) + require.ErrorIs(t, err, ErrSessionIsAlreadyRunning) + }) + t.Run("Finished", func(t *testing.T) { + helperForceFinishSession(t, m, "bogus") + _, err := m.Start(ctx, "bogus", "en", 1) + require.ErrorIs(t, err, ErrSessionFinished) + }) + t.Run("Expired", func(t *testing.T) { + helperForceResetSessionStartedAt(t, m, "bogus") + session, err := m.Start(ctx, "bogus", "en", 1) + require.NoError(t, err) + require.NotNil(t, session) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.NotEmpty(t, session.Progress.NextQuestion) + require.Equal(t, uint8(0), session.Progress.MaxQuestions) + }) + }) +} + +func helperSolveQuestion(t *testing.T, text string) uint8 { + t.Helper() + + switch text { + case "What is the capital of France?": + return 1 + case "What is the capital of Spain?": + return 2 + case "What is the capital of Germany?": + return 3 + default: + t.Errorf("unknown question: %s", text) + } + + return 0 +} + +func testManagerSessionContinueErrors(ctx context.Context, t *testing.T, m *managerImpl) { + helperDeleteAllSessions(t, m, "bogus") + + t.Run("UnknownSession", func(t *testing.T) { + _, err := m.Continue(ctx, "unknown_user", 1, 1) + require.ErrorIs(t, err, ErrUnknownSession) + }) + + t.Run("Finished", func(t *testing.T) { + defer helperDeleteAllSessions(t, m, "bogus") + + data, err := m.Start(ctx, "bogus", "en", 1) + require.NoError(t, err) + helperForceFinishSession(t, m, "bogus") + _, err = m.Continue(ctx, "bogus", data.Progress.NextQuestion.Number, 1) + require.ErrorIs(t, err, ErrSessionFinished) + }) + + t.Run("Expired", func(t *testing.T) { + defer helperDeleteAllSessions(t, m, "bogus") + + data, err := m.Start(ctx, "bogus", "en", 1) + require.NoError(t, err) + helperForceResetSessionStartedAt(t, m, "bogus") + _, err = m.Continue(ctx, "bogus", data.Progress.NextQuestion.Number, 1) + require.ErrorIs(t, err, ErrSessionExpired) + }) + + t.Run("UnknownQuestionNumber", func(t *testing.T) { + defer helperDeleteAllSessions(t, m, "bogus") + + _, err := m.Start(ctx, "bogus", "en", 1) + require.NoError(t, err) + _, err = m.Continue(ctx, "bogus", 0, 1) + require.ErrorIs(t, err, ErrUnknownQuestionNumber) + }) +} + +func testManagerSessionContinueWithCorrectAnswers(ctx context.Context, t *testing.T, m *managerImpl) { + helperDeleteAllSessions(t, m, "bogus") + + session, err := m.Start(ctx, "bogus", "en", 3) + require.NoError(t, err) + require.NotNil(t, session) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.NotEmpty(t, session.Progress.NextQuestion) + require.Equal(t, uint8(2), session.Progress.MaxQuestions) + require.NotEmpty(t, session.Progress.NextQuestion.Text) + + ans := helperSolveQuestion(t, session.Progress.NextQuestion.Text) + t.Logf("q: %v, ans: %d", session.Progress.NextQuestion.Text, ans) + session, err = m.Continue(ctx, "bogus", session.Progress.NextQuestion.Number, ans) + require.NoError(t, err) + require.NotNil(t, session) + require.Empty(t, session.Result) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.Equal(t, uint8(1), session.Progress.MaxQuestions) + require.NotEmpty(t, session.Progress.NextQuestion.Text) + require.Equal(t, uint8(1), session.Progress.CorrectAnswers) + require.Equal(t, uint8(0), session.Progress.IncorrectAnswers) + + ans = helperSolveQuestion(t, session.Progress.NextQuestion.Text) + t.Logf("q: %v, ans: %d", session.Progress.NextQuestion.Text, ans) + session, err = m.Continue(ctx, "bogus", session.Progress.NextQuestion.Number, ans) + require.NoError(t, err) + require.NotNil(t, session) + require.Empty(t, session.Result) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.Equal(t, uint8(0), session.Progress.MaxQuestions) + require.NotEmpty(t, session.Progress.NextQuestion.Text) + require.Equal(t, uint8(2), session.Progress.CorrectAnswers) + require.Equal(t, uint8(0), session.Progress.IncorrectAnswers) + + ans = helperSolveQuestion(t, session.Progress.NextQuestion.Text) + t.Logf("q: %v, ans: %d", session.Progress.NextQuestion.Text, ans) + session, err = m.Continue(ctx, "bogus", session.Progress.NextQuestion.Number, ans) + require.NoError(t, err) + require.NotNil(t, session) + require.Equal(t, SuccessResult, session.Result) + require.NotNil(t, session.Progress) + require.Nil(t, session.Progress.NextQuestion) + require.Equal(t, uint8(0), session.Progress.MaxQuestions) + require.Equal(t, uint8(3), session.Progress.CorrectAnswers) + require.Equal(t, uint8(0), session.Progress.IncorrectAnswers) +} + +func testManagerSessionContinueWithIncorrectAnswers(ctx context.Context, t *testing.T, m *managerImpl) { + helperDeleteAllSessions(t, m, "bogus") + + session, err := m.Start(ctx, "bogus", "en", 3) + require.NoError(t, err) + require.NotNil(t, session) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.NotEmpty(t, session.Progress.NextQuestion) + require.Equal(t, uint8(2), session.Progress.MaxQuestions) + require.NotEmpty(t, session.Progress.NextQuestion.Text) + + ans := helperSolveQuestion(t, session.Progress.NextQuestion.Text) + t.Logf("q: %v, ans: %d", session.Progress.NextQuestion.Text, ans) + session, err = m.Continue(ctx, "bogus", session.Progress.NextQuestion.Number, ans) + require.NoError(t, err) + require.NotNil(t, session) + require.Empty(t, session.Result) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.Equal(t, uint8(1), session.Progress.MaxQuestions) + require.NotEmpty(t, session.Progress.NextQuestion.Text) + require.Equal(t, uint8(1), session.Progress.CorrectAnswers) + require.Equal(t, uint8(0), session.Progress.IncorrectAnswers) + + ans = helperSolveQuestion(t, session.Progress.NextQuestion.Text) + t.Logf("q: %v, ans: %d", session.Progress.NextQuestion.Text, ans) + session, err = m.Continue(ctx, "bogus", session.Progress.NextQuestion.Number, ans) + require.NoError(t, err) + require.NotNil(t, session) + require.Empty(t, session.Result) + require.NotNil(t, session.Progress) + require.NotNil(t, session.Progress.ExpiresAt) + require.Equal(t, uint8(0), session.Progress.MaxQuestions) + require.NotEmpty(t, session.Progress.NextQuestion.Text) + require.Equal(t, uint8(2), session.Progress.CorrectAnswers) + require.Equal(t, uint8(0), session.Progress.IncorrectAnswers) + + session, err = m.Continue(ctx, "bogus", session.Progress.NextQuestion.Number, 0) + require.NoError(t, err) + require.NotNil(t, session) + require.Equal(t, FailureResult, session.Result) + require.NotNil(t, session.Progress) + require.Nil(t, session.Progress.NextQuestion) + require.Equal(t, uint8(0), session.Progress.MaxQuestions) + require.Equal(t, uint8(2), session.Progress.CorrectAnswers) + require.Equal(t, uint8(1), session.Progress.IncorrectAnswers) +} + +func TestSessionManager(t *testing.T) { t.Parallel() - m := NewManager(context.TODO()) - require.NotNil(t, m) + ctx := context.TODO() + + // Create user repo because we need its schema. + r := users.New(ctx, nil) + require.NotNil(t, r) - _, err := m.Start(context.Background(), "fooo", "en", 2) - require.ErrorIs(t, err, ErrUnknownUser) + m := newManagerImpl(ctx, new(mockUserReader)) + require.NotNil(t, m) - _, err = m.Continue(context.Background(), "fooo", 1, 1) - require.ErrorIs(t, err, ErrUnknownUser) + t.Run("Start", func(t *testing.T) { + testManagerSessionStart(ctx, t, m) + }) + t.Run("Continue", func(t *testing.T) { + testManagerSessionContinueErrors(ctx, t, m) + testManagerSessionContinueWithCorrectAnswers(ctx, t, m) + testManagerSessionContinueWithIncorrectAnswers(ctx, t, m) + }) require.NoError(t, m.Close()) }