Skip to content

Commit

Permalink
Merge pull request #89 from sonroyaalmerol/improve-errors
Browse files Browse the repository at this point in the history
iron out some bugs with websockets and improve store mutex locking
  • Loading branch information
sonroyaalmerol authored Jan 28, 2025
2 parents 49feca4 + 1773b54 commit c390a6c
Show file tree
Hide file tree
Showing 30 changed files with 1,269 additions and 797 deletions.
1 change: 1 addition & 0 deletions cmd/pbs_plus/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ func main() {

// Agent auth routes
mux.HandleFunc("/plus/agent/bootstrap", mw.CORS(storeInstance, agents.AgentBootstrapHandler(storeInstance)))
mux.HandleFunc("/plus/agent/renew", mw.AgentOnly(storeInstance, mw.CORS(storeInstance, agents.AgentRenewHandler(storeInstance))))

server := &http.Server{
Addr: serverConfig.Address,
Expand Down
3 changes: 3 additions & 0 deletions cmd/windows_agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
_ "net/http/pprof"

"github.com/kardianos/service"
"github.com/sonroyaalmerol/pbs-plus/internal/store/constants"
"github.com/sonroyaalmerol/pbs-plus/internal/syslog"
"golang.org/x/sys/windows/registry"
)
Expand Down Expand Up @@ -90,6 +91,8 @@ func (w *watchdogService) Stop(s service.Service) error {
}

func main() {
constants.Version = Version

go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
Expand Down
50 changes: 25 additions & 25 deletions cmd/windows_agent/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,25 @@ func (p *agentService) Start(s service.Service) error {
p.svc = s
p.ctx, p.cancel = context.WithCancel(context.Background())

p.wg.Add(1)
p.wg.Add(2)
go func() {
defer p.wg.Done()
p.run()
}()
// Background certificate maintenance: once an hour, check the agent
// certificate and renew it if needed, until the service context is cancelled.
go func() {
	defer p.wg.Done()
	for {
		select {
		case <-p.ctx.Done():
			return
		case <-time.After(time.Hour):
			if err := agent.CheckAndRenewCertificate(); err != nil {
				// %v rather than %w: the wrapping verb is only defined for
				// fmt.Errorf; in any other printf-style call it is an
				// invalid verb and garbles the logged message.
				syslog.L.Errorf("Certificate renewal manager: %v", err)
			}
		}
	}
}()

return nil
}
Expand Down Expand Up @@ -110,7 +124,7 @@ func (p *agentService) waitForServerURL() error {
}

func (p *agentService) waitForBootstrap() error {
ticker := time.NewTicker(5 * time.Second)
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()

for {
Expand All @@ -119,10 +133,16 @@ func (p *agentService) waitForBootstrap() error {
priv, _ := registry.GetEntry(registry.AUTH, "Priv", true)

if serverCA != nil && cert != nil && priv != nil {
return nil
err := agent.CheckAndRenewCertificate()
if err == nil {
return nil
}
syslog.L.Errorf("Renewal error: %v", err)
} else {
err := agent.Bootstrap()
syslog.L.Errorf("Bootstrap error: %v", err)
if err != nil {
syslog.L.Errorf("Bootstrap error: %v", err)
}
}

select {
Expand Down Expand Up @@ -201,27 +221,7 @@ func (p *agentService) connectWebSocket() error {
return err
}

client, err := websockets.NewWSClient(p.ctx, config, tlsConfig)
if err != nil {
syslog.L.Errorf("WS client init error: %s", err)
select {
case <-p.ctx.Done():
return fmt.Errorf("context cancelled while connecting to WebSocket")
case <-time.After(5 * time.Second):
continue
}
}

err = client.Connect()
if err != nil {
syslog.L.Errorf("WS client connect error: %s", err)
select {
case <-p.ctx.Done():
return fmt.Errorf("context cancelled while connecting to WebSocket")
case <-time.After(5 * time.Second):
continue
}
}
client := websockets.NewWSClient(p.ctx, config, tlsConfig)

client.RegisterHandler("backup_start", controllers.BackupStartHandler(client))
client.RegisterHandler("backup_close", controllers.BackupCloseHandler(client))
Expand Down
2 changes: 2 additions & 0 deletions internal/agent/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/sonroyaalmerol/pbs-plus/internal/agent/registry"
"github.com/sonroyaalmerol/pbs-plus/internal/store/constants"
)

var httpClient *http.Client
Expand Down Expand Up @@ -40,6 +41,7 @@ func ProxmoxHTTPRequest(method, url string, body io.Reader, respBody any) (io.Re

req.Header.Add("Content-Type", "application/json")
req.Header.Add("X-PBS-Agent", hostname)
req.Header.Add("X-PBS-Plus-Version", constants.Version)

tlsConfig, err := GetTLSConfig()
if err != nil {
Expand Down
4 changes: 1 addition & 3 deletions internal/agent/status_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// BackupStatus manages the state of ongoing backups
type BackupStatus struct {
activeBackups sync.Map
mu sync.RWMutex
mu sync.RWMutex
}

var (
Expand Down Expand Up @@ -44,5 +44,3 @@ func (bs *BackupStatus) HasActiveBackups() bool {
})
return hasActive
}


122 changes: 122 additions & 0 deletions internal/agent/tls_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
package agent

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"os"
"time"

"github.com/sonroyaalmerol/pbs-plus/internal/agent/registry"
"github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates"
"github.com/sonroyaalmerol/pbs-plus/internal/utils"
)

func GetTLSConfig() (*tls.Config, error) {
Expand Down Expand Up @@ -45,3 +54,116 @@ func GetTLSConfig() (*tls.Config, error) {
RootCAs: rootCAs,
}, nil
}

// CheckAndRenewCertificate inspects the agent certificate stored in the
// registry (registry.AUTH, key "Cert") and acts on its remaining lifetime:
//
//   - already expired: the stored "Cert" and "Priv" entries are deleted and an
//     error is returned, because the agent has to be bootstrapped again;
//   - expiring within the 30-day renewal window: renewCertificate is called
//     and its result is returned;
//   - otherwise: nothing is changed and nil is returned.
//
// An error is also returned when the entry cannot be read, is missing, does
// not contain a PEM block, or does not parse as an X.509 certificate.
func CheckAndRenewCertificate() error {
	const renewalWindow = 30 * 24 * time.Hour // Renew if certificate expires in less than 30 days

	certReg, err := registry.GetEntry(registry.AUTH, "Cert", true)
	if err != nil {
		return fmt.Errorf("CheckAndRenewCertificate: failed to retrieve certificate - %w", err)
	}
	// Guard against a nil entry without an error so the dereference below
	// cannot panic the periodic renewal goroutine.
	if certReg == nil {
		return fmt.Errorf("CheckAndRenewCertificate: failed to retrieve certificate - entry not found")
	}

	block, _ := pem.Decode([]byte(certReg.Value))
	if block == nil {
		return fmt.Errorf("CheckAndRenewCertificate: failed to decode PEM block")
	}

	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return fmt.Errorf("CheckAndRenewCertificate: failed to parse certificate - %w", err)
	}

	// Take a single timestamp so the expiry test and the reported remaining
	// lifetime are computed against the same instant.
	now := time.Now()
	timeUntilExpiry := cert.NotAfter.Sub(now)

	switch {
	case cert.NotAfter.Before(now):
		// Best-effort cleanup: deletion errors are deliberately ignored
		// because the expiry error below is returned either way.
		_ = registry.DeleteEntry(registry.AUTH, "Cert")
		_ = registry.DeleteEntry(registry.AUTH, "Priv")

		return fmt.Errorf("Certificate has expired. This agent needs to be bootstrapped again.")
	case timeUntilExpiry < renewalWindow:
		fmt.Printf("Certificate expires in %v hours. Renewing...\n", timeUntilExpiry.Hours())
		return renewCertificate()
	default:
		fmt.Printf("Certificate valid for %v days. No renewal needed.\n", timeUntilExpiry.Hours()/24)
		return nil
	}
}

// renewCertificate obtains a fresh certificate from the server and stores it.
//
// It generates a new CSR and private key for the local host name, POSTs the
// base64-encoded CSR (together with the host name and local drives) to
// /plus/agent/renew, base64-decodes the CA and certificate from the response,
// and writes "ServerCA", "Cert" and "Priv" to registry.AUTH as secret entries.
//
// The three entries are written one after another; an error part-way through
// leaves the earlier ones already stored.
func renewCertificate() error {
	// Fail early instead of requesting a certificate for an empty host name.
	hostname, err := os.Hostname()
	if err != nil {
		return fmt.Errorf("Renew: resolving hostname failed -> %w", err)
	}

	// 2048 is passed through to GenerateCSR (presumably the RSA key size in
	// bits — confirm against certificates.GenerateCSR).
	csr, privKey, err := certificates.GenerateCSR(hostname, 2048)
	if err != nil {
		return fmt.Errorf("Renew: generating csr failed -> %w", err)
	}

	encodedCSR := base64.StdEncoding.EncodeToString(csr)

	// The renew endpoint is sent the same payload type as bootstrap.
	reqBody, err := json.Marshal(&BootstrapRequest{
		Hostname: hostname,
		Drives:   utils.GetLocalDrives(),
		CSR:      encodedCSR,
	})
	if err != nil {
		return fmt.Errorf("failed to marshal bootstrap request: %w", err)
	}

	renewResp := &BootstrapResponse{}

	_, err = ProxmoxHTTPRequest(http.MethodPost, "/plus/agent/renew", bytes.NewBuffer(reqBody), &renewResp)
	if err != nil {
		return fmt.Errorf("failed to fetch renewed certificate: %w", err)
	}

	decodedCA, err := base64.StdEncoding.DecodeString(renewResp.CA)
	if err != nil {
		return fmt.Errorf("Renew: error decoding ca content (%s) -> %w", string(renewResp.CA), err)
	}

	decodedCert, err := base64.StdEncoding.DecodeString(renewResp.Cert)
	if err != nil {
		return fmt.Errorf("Renew: error decoding cert content (%s) -> %w", string(renewResp.Cert), err)
	}

	privKeyPEM := certificates.EncodeKeyPEM(privKey)

	caEntry := registry.RegistryEntry{
		Key:      "ServerCA",
		Value:    string(decodedCA),
		Path:     registry.AUTH,
		IsSecret: true,
	}

	certEntry := registry.RegistryEntry{
		Key:      "Cert",
		Value:    string(decodedCert),
		Path:     registry.AUTH,
		IsSecret: true,
	}

	privEntry := registry.RegistryEntry{
		Key:      "Priv",
		Value:    string(privKeyPEM),
		Path:     registry.AUTH,
		IsSecret: true,
	}

	err = registry.CreateEntry(&caEntry)
	if err != nil {
		return fmt.Errorf("Renew: error storing ca to registry -> %w", err)
	}

	err = registry.CreateEntry(&certEntry)
	if err != nil {
		return fmt.Errorf("Renew: error storing cert to registry -> %w", err)
	}

	err = registry.CreateEntry(&privEntry)
	if err != nil {
		return fmt.Errorf("Renew: error storing priv to registry -> %w", err)
	}

	return nil
}
1 change: 1 addition & 0 deletions internal/backend/mount/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func Mount(storeInstance *store.Store, target *types.Target) (*AgentMount, error

// If all retries failed, clean up and return error
agentMount.Unmount()
agentMount.CloseMount()
return nil, fmt.Errorf("Mount: error mounting NFS share after %d attempts -> %w", maxRetries, lastErr)
}

Expand Down
77 changes: 77 additions & 0 deletions internal/config/mutex_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package config

import (
"path/filepath"
"sync"
)

// FileMutexManager hands out one reference-counted sync.RWMutex per file
// path. A per-path lock is created on first use and removed from the manager
// again once the last holder has released it, so the internal maps only
// contain paths that are currently in use.
type FileMutexManager struct {
	// mu guards locks and refCnt.
	mu sync.RWMutex
	// locks maps a normalized path to its lock.
	locks map[string]*sync.RWMutex
	// refCnt counts the outstanding acquisitions for each normalized path.
	refCnt map[string]int
}

// NewFileMutexManager returns an empty, ready-to-use FileMutexManager.
func NewFileMutexManager() *FileMutexManager {
	return &FileMutexManager{
		locks:  make(map[string]*sync.RWMutex),
		refCnt: make(map[string]int),
	}
}

// normalizeLockPath turns path into the key used by the manager's maps: the
// absolute form of path, or path itself when it cannot be made absolute.
func normalizeLockPath(path string) string {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return path
	}
	return absPath
}

// acquire returns the lock registered under key, creating it if necessary,
// and records one more reference to it. key must already be normalized.
func (fm *FileMutexManager) acquire(key string) *sync.RWMutex {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	if lock, exists := fm.locks[key]; exists {
		fm.refCnt[key]++
		return lock
	}

	lock := &sync.RWMutex{}
	fm.locks[key] = lock
	fm.refCnt[key] = 1
	return lock
}

// release drops one reference to the lock registered under key and forgets
// the lock once nobody references it any more. key must already be normalized.
func (fm *FileMutexManager) release(key string) {
	fm.mu.Lock()
	defer fm.mu.Unlock()

	if fm.refCnt[key] > 1 {
		fm.refCnt[key]--
		return
	}

	delete(fm.locks, key)
	delete(fm.refCnt, key)
}

// getLock returns the lock for path and increments its reference count.
// Every call must be balanced by a releaseLock call for the same path.
func (fm *FileMutexManager) getLock(path string) *sync.RWMutex {
	// Resolve the path before taking fm.mu: making a relative path absolute
	// may query the OS and must not stall every other path's callers.
	return fm.acquire(normalizeLockPath(path))
}

// releaseLock decrements the reference count for path and removes the lock
// from the manager when the count reaches zero.
func (fm *FileMutexManager) releaseLock(path string) {
	fm.release(normalizeLockPath(path))
}

// WithReadLock runs fn while holding the read lock for path and returns fn's
// error. The lock is released even if fn panics.
func (fm *FileMutexManager) WithReadLock(path string, fn func() error) error {
	// Normalize once and reuse the key for the release: if fn changes the
	// working directory, re-resolving a relative path afterwards would yield
	// a different key and the entry would never be freed.
	key := normalizeLockPath(path)
	lock := fm.acquire(key)
	lock.RLock()
	defer func() {
		lock.RUnlock()
		fm.release(key)
	}()
	return fn()
}

// WithWriteLock runs fn while holding the exclusive lock for path and returns
// fn's error. The lock is released even if fn panics.
func (fm *FileMutexManager) WithWriteLock(path string, fn func() error) error {
	// Same key for acquire and release; see WithReadLock.
	key := normalizeLockPath(path)
	lock := fm.acquire(key)
	lock.Lock()
	defer func() {
		lock.Unlock()
		fm.release(key)
	}()
	return fn()
}
Loading

0 comments on commit c390a6c

Please sign in to comment.