Skip to content

Commit

Permalink
correctly identify ca-cert vs ca-ra-cert
Browse files Browse the repository at this point in the history
  • Loading branch information
groob committed Jun 9, 2016
1 parent fc62c53 commit c15d1c1
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 20 deletions.
6 changes: 3 additions & 3 deletions client/scep.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ func (c *client) GetCACaps(ctx context.Context) ([]byte, error) {
return r.Data, nil
}

func (c *client) GetCACert(ctx context.Context) ([]byte, error) {
func (c *client) GetCACert(ctx context.Context) ([]byte, int, error) {
request := scepserver.SCEPRequest{
Operation: "GetCACert",
}
reply, err := c.getRemote(ctx, request)
if err != nil {
return nil, err
return nil, 0, err
}
r := reply.(scepserver.SCEPResponse)
return r.Data, nil
return r.Data, r.CACertNum, nil
}

func (c *client) PKIOperation(ctx context.Context, data []byte) ([]byte, error) {
Expand Down
18 changes: 14 additions & 4 deletions cmd/scepclient/scepclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,23 @@ func run(cfg runCfg) error {
client = scepclient.NewClient(cfg.serverURL)
}

resp, err := client.GetCACert(ctx)
resp, certNum, err := client.GetCACert(ctx)
if err != nil {
return err
}
certs, err := scep.CACerts(resp)
if err != nil {
return err
var certs []*x509.Certificate
{
if certNum > 1 {
certs, err = scep.CACerts(resp)
if err != nil {
return err
}
} else {
certs, err = x509.ParseCertificates(resp)
if err != nil {
return err
}
}
}

var signerCert *x509.Certificate
Expand Down
5 changes: 3 additions & 2 deletions server/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type SCEPRequest struct {
// Business errors will be encoded as a CertRep message
// with pkiStatus FAILURE and a failInfo attribute.
type SCEPResponse struct {
Data []byte
Err error // response error
CACertNum int //chain
Data []byte
Err error // response error
}
12 changes: 8 additions & 4 deletions server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Service interface {
// GetCACert returns CA certificate or
// a CA certificate chain with intermediates
// in a PKCS#7 Degenerate Certificates format
GetCACert(ctx context.Context) ([]byte, error)
GetCACert(ctx context.Context) ([]byte, int, error)

// PKIOperation handles incoming SCEP messages such as PKCSReq and
// sends back a CertRep PKIMessag.
Expand All @@ -50,11 +50,15 @@ func (svc service) GetCACaps(ctx context.Context) ([]byte, error) {
return defaultCaps, nil
}

func (svc service) GetCACert(ctx context.Context) ([]byte, error) {
func (svc service) GetCACert(ctx context.Context) ([]byte, int, error) {
if len(svc.ca) == 0 {
return nil, errors.New("missing CA Cert")
return nil, 0, errors.New("missing CA Cert")
}
return scep.DegenerateCertificates(svc.ca)
if len(svc.ca) == 1 {
return svc.ca[0].Raw, 1, nil
}
data, err := scep.DegenerateCertificates(svc.ca)
return data, len(svc.ca), err
}

func (svc service) PKIOperation(ctx context.Context, data []byte) ([]byte, error) {
Expand Down
4 changes: 2 additions & 2 deletions server/service_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ func (mw loggingService) GetCACaps(ctx context.Context) (caps []byte, err error)
return
}

func (mw loggingService) GetCACert(ctx context.Context) (cert []byte, err error) {
func (mw loggingService) GetCACert(ctx context.Context) (cert []byte, certNum int, err error) {
defer func(begin time.Time) {
_ = mw.logger.Log(
"method", "GetCACert",
"err", err,
"took", time.Since(begin),
)
}(time.Now())
cert, err = mw.Service.GetCACert(ctx)
cert, certNum, err = mw.Service.GetCACert(ctx)
return
}

Expand Down
20 changes: 15 additions & 5 deletions server/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func encodeSCEPResponse(ctx context.Context, w http.ResponseWriter, response int
fmt.Println(resp.Err)
return resp.Err
}
w.Header().Set("Content-Type", contentHeader(ctx))
w.Header().Set("Content-Type", contentHeader(ctx, resp.CACertNum))
w.Write(resp.Data)
return nil
}
Expand All @@ -118,19 +118,29 @@ func DecodeSCEPResponse(ctx context.Context, r *http.Response) (interface{}, err
resp := SCEPResponse{
Data: data,
}
header := r.Header.Get("Content-Type")
if header == certChainHeader {
// TODO decode the response instead of just passing []byte around
// 0 or 1
resp.CACertNum = 2
}
return resp, nil
}

const (
certChainHeader = "application/x-x509-ca-ra-cert"
leafHeader = "application/x-x509-ca-cert"
pkiOpHeader = "application/x-pki-message"
)

func contentHeader(ctx context.Context) string {
func contentHeader(ctx context.Context, certNum int) string {
op := ctx.Value("operation")
switch op {
case "GetCACert":
return certChainHeader
if certNum > 1 {
return certChainHeader
}
return leafHeader
case "PKIOperation":
return pkiOpHeader
default:
Expand All @@ -153,9 +163,9 @@ func makeSCEPEndpoint(svc Service) endpoint.Endpoint {
}
return SCEPResponse{Data: caps}, nil
case "GetCACert":
cert, err := svc.GetCACert(ctx)
cert, certNum, err := svc.GetCACert(ctx)
if err != nil {
return SCEPResponse{Err: err}, nil
return SCEPResponse{Err: err, CACertNum: certNum}, nil
}
return SCEPResponse{Data: cert}, nil
case "PKIOperation":
Expand Down

0 comments on commit c15d1c1

Please sign in to comment.