diff --git a/gapi/error.go b/gapi/error.go new file mode 100644 index 0000000..b3b6db5 --- /dev/null +++ b/gapi/error.go @@ -0,0 +1,30 @@ +package gapi + +import ( + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func fieldViolation(field string, err error) *errdetails.BadRequest_FieldViolation { + return &errdetails.BadRequest_FieldViolation{ + Field: field, + Description: err.Error(), + } +} + +func invalidArgumentError(violations []*errdetails.BadRequest_FieldViolation) error { + badRequest := &errdetails.BadRequest{FieldViolations: violations} + statusInvalid := status.New(codes.InvalidArgument, "invalid parameters") + + statusDetails, err := statusInvalid.WithDetails(badRequest) + if err != nil { + return statusInvalid.Err() + } + + return statusDetails.Err() +} + +func unauthenticatedError(err error) error { + return status.Errorf(codes.Unauthenticated, "unauthorized: %s", err) +} diff --git a/gapi/rpc_create_user.go b/gapi/rpc_create_user.go index eb53e0b..7817ee3 100644 --- a/gapi/rpc_create_user.go +++ b/gapi/rpc_create_user.go @@ -6,12 +6,18 @@ import ( db "github.com/eizyc/simplebank/db/sqlc" "github.com/eizyc/simplebank/pb" "github.com/eizyc/simplebank/util" + "github.com/eizyc/simplebank/val" "github.com/jackc/pgx/v5/pgconn" + "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func (server *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (*pb.CreateUserResponse, error) { + violations := validateCreateUserRequest(req) + if violations != nil { + return nil, invalidArgumentError(violations) + } hashedPassword, err := util.HashPassword(req.GetPassword()) if err != nil { return nil, status.Errorf(codes.Internal, "failed to hash password: %s", err) @@ -42,3 +48,23 @@ func (server *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) } return rsp, nil } + +func validateCreateUserRequest(req *pb.CreateUserRequest) (violations []*errdetails.BadRequest_FieldViolation) { + if err := val.ValidateUsername(req.GetUsername()); err != nil { + violations = append(violations, fieldViolation("username", err)) + } + + if err := val.ValidatePassword(req.GetPassword()); err != nil { + violations = append(violations, fieldViolation("password", err)) + } + + if err := val.ValidateFullName(req.GetFullName()); err != nil { + violations = append(violations, fieldViolation("full_name", err)) + } + + if err := val.ValidateEmail(req.GetEmail()); err != nil { + violations = append(violations, fieldViolation("email", err)) + } + + return violations +} diff --git a/gapi/rpc_login_user.go b/gapi/rpc_login_user.go index bc80a2d..82c5f44 100644 --- a/gapi/rpc_login_user.go +++ b/gapi/rpc_login_user.go @@ -7,12 +7,19 @@ import ( db "github.com/eizyc/simplebank/db/sqlc" "github.com/eizyc/simplebank/pb" "github.com/eizyc/simplebank/util" + "github.com/eizyc/simplebank/val" + "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" ) func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (*pb.LoginUserResponse, error) { + violations := validateLoginUserRequest(req) + if violations != nil { + return nil, invalidArgumentError(violations) + } + user, err := server.store.GetUser(ctx, req.GetUsername()) if err != nil { if err == sql.ErrNoRows { @@ -66,3 +73,15 @@ func (server *Server) LoginUser(ctx context.Context, req *pb.LoginUserRequest) ( } return rsp, nil } + +func validateLoginUserRequest(req *pb.LoginUserRequest) (violations []*errdetails.BadRequest_FieldViolation) { + if err := val.ValidateUsername(req.GetUsername()); err != nil { + violations = append(violations, fieldViolation("username", err)) + } + + if err := val.ValidatePassword(req.GetPassword()); err != nil { + violations = append(violations, fieldViolation("password", err)) + } + + return violations +} diff --git a/val/validator.go b/val/validator.go new file mode 100644 index 0000000..83aa672 --- /dev/null +++ b/val/validator.go @@ -0,0 +1,65 @@ +package val + +import ( + "fmt" + "net/mail" + "regexp" +) + +var ( + isValidUsername = regexp.MustCompile(`^[a-z0-9_]+$`).MatchString + isValidFullName = regexp.MustCompile(`^[a-zA-Z\s]+$`).MatchString +) + +func ValidateString(value string, minLength int, maxLength int) error { + n := len(value) + if n < minLength || n > maxLength { + return fmt.Errorf("must contain from %d-%d characters", minLength, maxLength) + } + return nil +} + +func ValidateUsername(value string) error { + if err := ValidateString(value, 3, 100); err != nil { + return err + } + if !isValidUsername(value) { + return fmt.Errorf("must contain only lowercase letters, digits, or underscore") + } + return nil +} + +func ValidateFullName(value string) error { + if err := ValidateString(value, 3, 100); err != nil { + return err + } + if !isValidFullName(value) { + return fmt.Errorf("must contain only letters or spaces") + } + return nil +} + +func ValidatePassword(value string) error { + return ValidateString(value, 6, 100) +} + +func ValidateEmail(value string) error { + if err := ValidateString(value, 3, 200); err != nil { + return err + } + if _, err := mail.ParseAddress(value); err != nil { + return fmt.Errorf("is not a valid email address") + } + return nil +} + +func ValidateEmailId(value int64) error { + if value <= 0 { + return fmt.Errorf("must be a positive integer") + } + return nil +} + +func ValidateSecretCode(value string) error { + return ValidateString(value, 32, 128) +}