From 3d02ce84da97514bae24a718043e244010329610 Mon Sep 17 00:00:00 2001 From: aftermath2 Date: Tue, 26 Dec 2023 22:55:57 -0300 Subject: [PATCH] Set response fields when evaluating the policy --- main.go | 15 ++------------- policy/policy.go | 32 ++++++++++++++++++++++---------- policy/policy_test.go | 39 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index de4761f..936235d 100644 --- a/main.go +++ b/main.go @@ -105,7 +105,7 @@ func handleRequest( node, err := client.GetInfo(ctx, &lnrpc.GetInfoRequest{}) if err != nil { - return resp, errors.Wrap(err, "getting node information") + return resp, errors.New("Internal server error") } getPeerInfoReq := &lnrpc.NodeInfoRequest{ @@ -119,20 +119,9 @@ func handleRequest( slog.Debug("Peer node information", slog.Any("node", peer)) for _, policy := range config.Policies { - if err := policy.Evaluate(req, node, peer); err != nil { + if err := policy.Evaluate(req, resp, node, peer); err != nil { return resp, err } - - if policy.MinAcceptDepth != nil { - resp.MinAcceptDepth = *policy.MinAcceptDepth - } - } - - if req.WantsZeroConf && len(config.Policies) != 0 { - // The initiator requested a zero conf channel and it was explicitly accepted, set the - // fields required to open it - resp.ZeroConf = true - resp.MinAcceptDepth = 0 } return resp, nil diff --git a/policy/policy.go b/policy/policy.go index 18b2abb..c94ebac 100644 --- a/policy/policy.go +++ b/policy/policy.go @@ -27,6 +27,7 @@ type Policy struct { // Evaluate set of policies. func (p *Policy) Evaluate( req *lnrpc.ChannelAcceptRequest, + resp *lnrpc.ChannelAcceptResponse, node *lnrpc.GetInfoResponse, peer *lnrpc.NodeInfo, ) error { @@ -34,6 +35,10 @@ func (p *Policy) Evaluate( return nil } + if p.MinAcceptDepth != nil { + resp.MinAcceptDepth = *p.MinAcceptDepth + } + if !p.checkRejectAll() { return errors.New("No new channels are accepted") } @@ -50,7 +55,7 @@ func (p *Policy) Evaluate( return errors.New("Private channels are not accepted") } - if !p.checkZeroConf(peer.Node.PubKey, req.WantsZeroConf) { + if !p.checkZeroConf(peer.Node.PubKey, req.WantsZeroConf, resp) { return errors.New("Zero conf channels are not accepted") } @@ -101,7 +106,11 @@ func (p *Policy) checkPrivate(private bool) bool { return private && !*p.RejectPrivateChannels } -func (p *Policy) checkZeroConf(publicKey string, wantsZeroConf bool) bool { +func (p *Policy) checkZeroConf( + publicKey string, + wantsZeroConf bool, + resp *lnrpc.ChannelAcceptResponse, +) bool { if !wantsZeroConf { return true } @@ -110,15 +119,18 @@ func (p *Policy) checkZeroConf(publicKey string, wantsZeroConf bool) bool { return false } - if p.ZeroConfList != nil { - for _, pubKey := range *p.ZeroConfList { - if publicKey == pubKey { - return true - } - } + resp.ZeroConf = true + resp.MinAcceptDepth = 0 - return false + if p.ZeroConfList == nil { + return true } - return true + for _, pubKey := range *p.ZeroConfList { + if publicKey == pubKey { + return true + } + } + + return false } diff --git a/policy/policy_test.go b/policy/policy_test.go index 59b28a1..cd608e8 100644 --- a/policy/policy_test.go +++ b/policy/policy_test.go @@ -19,6 +19,7 @@ func TestEvaluatePolicy(t *testing.T) { tru := true fals := false max := uint64(1) + depth := uint32(10) cases := []struct { policy Policy @@ -138,11 +139,20 @@ func TestEvaluatePolicy(t *testing.T) { }, fail: true, }, + { + desc: "Min accept depth", + policy: Policy{ + MinAcceptDepth: &depth, + }, + req: defaultReq, + peer: defaultPeer, + fail: false, + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - err := tc.policy.Evaluate(tc.req, node, tc.peer) + err := tc.policy.Evaluate(tc.req, &lnrpc.ChannelAcceptResponse{}, node, tc.peer) if tc.fail { assert.NotNil(t, err) } else { @@ -152,6 +162,25 @@ func TestEvaluatePolicy(t *testing.T) { } } +func TestMinAcceptDepth(t *testing.T) { + n := uint32(2) + policy := Policy{ + MinAcceptDepth: &n, + } + resp := &lnrpc.ChannelAcceptResponse{} + node := &lnrpc.NodeInfo{Node: &lnrpc.LightningNode{PubKey: ""}} + + err := policy.Evaluate( + &lnrpc.ChannelAcceptRequest{}, + resp, + &lnrpc.GetInfoResponse{}, + node, + ) + assert.NoError(t, err) + + assert.Equal(t, n, resp.MinAcceptDepth) +} + func TestCheckRejectAll(t *testing.T) { cases := []struct { desc string @@ -373,8 +402,14 @@ func TestCheckZeroConf(t *testing.T) { ZeroConfList: tc.zeroConfList, } - actual := policy.checkZeroConf(tc.publicKey, tc.wantsZeroConf) + resp := &lnrpc.ChannelAcceptResponse{} + actual := policy.checkZeroConf(tc.publicKey, tc.wantsZeroConf, resp) assert.Equal(t, tc.expected, actual) + + if tc.wantsZeroConf && tc.expected { + assert.True(t, resp.ZeroConf) + assert.Zero(t, resp.MinAcceptDepth) + } }) } }