diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..176a458 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto diff --git a/.github/workflows/development.yml b/.github/workflows/development.yml index 2997a72..aea4200 100644 --- a/.github/workflows/development.yml +++ b/.github/workflows/development.yml @@ -1,106 +1,106 @@ -on: - push: - branches: - - main - -permissions: - contents: write - packages: write - -jobs: - test: - strategy: - matrix: - os: [ubuntu-latest] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 - with: - go-version: '1.23' - cache: true - - run: go mod download - - name: Run Linux tests - shell: bash - if: matrix.os == 'ubuntu-latest' - run: GODEBUG=gctrace=1 go test -v -race ./... -count=1 -timeout 1m - - dev-build-linux-amd64: - name: dev build linux/amd64 - runs-on: ubuntu-latest - needs: test - steps: - - uses: actions/checkout@v4 - - uses: wangyoucao577/go-release-action@v1 - id: go_build - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - goos: linux - goarch: amd64 - compress_assets: false - executable_compression: upx - project_path: ./cmd/pbs_plus - overwrite: true - release_tag: dev - - name: pre-packaging script - env: - BINARY_PATH: ${{ steps.go_build.outputs.release_asset_dir }} - run: ./build/package/pre-packaging.sh - - uses: jiro4989/build-deb-action@v3 - with: - package: ${{ github.event.repository.name }} - package_root: build/package/debian - maintainer: Son Roy Almerol - version: 'refs/tags/v0.0.0' - arch: 'amd64' - depends: 'proxmox-backup-server (>= 3.2), proxmox-backup-client (>= 3.2.5), rclone, fuse3' - desc: 'PBS Plus is a project focused on extending Proxmox Backup Server (PBS) with advanced features to create a more competitive backup solution' - homepage: 'https://github.com/${{ github.repository }}' - - name: Pre-release dev build - uses: softprops/action-gh-release@v1 - with: - tag_name: dev - files: ./*.deb - prerelease: true - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - dev-build-windows-amd64-agent: - name: dev build agent windows/amd64 - runs-on: ubuntu-latest - needs: test - steps: - - uses: actions/checkout@v4 - - uses: wangyoucao577/go-release-action@v1 - id: go_build - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - goos: windows - goarch: amd64 - compress_assets: false - executable_compression: upx - binary_name: pbs-plus-agent - project_path: ./cmd/windows_agent - ldflags: -H=windowsgui - overwrite: true - release_tag: dev - - dev-build-windows-amd64-updater: - name: dev build updater windows/amd64 - runs-on: ubuntu-latest - needs: test - steps: - - uses: actions/checkout@v4 - - uses: wangyoucao577/go-release-action@v1 - id: go_build_updater - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - goos: windows - goarch: amd64 - compress_assets: false - executable_compression: upx - binary_name: pbs-plus-updater - project_path: ./cmd/windows_updater - ldflags: -H=windowsgui - overwrite: true - release_tag: dev +on: + push: + branches: + - main + +permissions: + contents: write + packages: write + +jobs: + test: + strategy: + matrix: + os: [ubuntu-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version: '1.23' + cache: true + - run: go mod download + - name: Run Linux tests + shell: bash + if: matrix.os == 'ubuntu-latest' + run: GODEBUG=gctrace=1 go test -v -race ./... 
-count=1 -timeout 1m + + dev-build-linux-amd64: + name: dev build linux/amd64 + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + - uses: wangyoucao577/go-release-action@v1 + id: go_build + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + goos: linux + goarch: amd64 + compress_assets: false + executable_compression: upx + project_path: ./cmd/pbs_plus + overwrite: true + release_tag: dev + - name: pre-packaging script + env: + BINARY_PATH: ${{ steps.go_build.outputs.release_asset_dir }} + run: ./build/package/pre-packaging.sh + - uses: jiro4989/build-deb-action@v3 + with: + package: ${{ github.event.repository.name }} + package_root: build/package/debian + maintainer: Son Roy Almerol + version: 'refs/tags/v0.0.0' + arch: 'amd64' + depends: 'proxmox-backup-server (>= 3.2), proxmox-backup-client (>= 3.2.5), rclone, fuse3' + desc: 'PBS Plus is a project focused on extending Proxmox Backup Server (PBS) with advanced features to create a more competitive backup solution' + homepage: 'https://github.com/${{ github.repository }}' + - name: Pre-release dev build + uses: softprops/action-gh-release@v1 + with: + tag_name: dev + files: ./*.deb + prerelease: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + dev-build-windows-amd64-agent: + name: dev build agent windows/amd64 + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + - uses: wangyoucao577/go-release-action@v1 + id: go_build + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + goos: windows + goarch: amd64 + compress_assets: false + executable_compression: upx + binary_name: pbs-plus-agent + project_path: ./cmd/windows_agent + ldflags: -H=windowsgui + overwrite: true + release_tag: dev + + dev-build-windows-amd64-updater: + name: dev build updater windows/amd64 + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + - uses: wangyoucao577/go-release-action@v1 + id: go_build_updater + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + goos: windows + goarch: amd64 + compress_assets: false + executable_compression: upx + binary_name: pbs-plus-updater + project_path: ./cmd/windows_updater + ldflags: -H=windowsgui + overwrite: true + release_tag: dev diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 53bccd8..6657065 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,131 +1,131 @@ -on: - release: - types: [created] - -permissions: - contents: write - packages: write - -jobs: - release-linux-amd64: - name: release linux/amd64 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: wangyoucao577/go-release-action@v1 - id: go_build - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - goos: linux - goarch: amd64 - compress_assets: false - executable_compression: upx - project_path: ./cmd/pbs_plus - ldflags: "-X 'main.Version=${{ github.event.release.tag_name }}'" - - name: pre-packaging script - env: - BINARY_PATH: ${{steps.go_build.outputs.release_asset_dir}} - run: ./build/package/pre-packaging.sh - - uses: jiro4989/build-deb-action@v3 - with: - package: ${{ github.event.repository.name }} - package_root: build/package/debian - maintainer: Son Roy Almerol - version: ${{ github.ref }} # refs/tags/v*.*.* - arch: 'amd64' - depends: 'proxmox-backup-server (>= 3.2), proxmox-backup-client (>= 3.2.5), rclone, fuse3' - desc: 'PBS Plus is a project focused on extending Proxmox Backup Server (PBS) with advanced features to create a more competitive backup solution' - homepage: 
'https://github.com/${{ github.repository }}' - - name: Publish Release Assets - uses: softprops/action-gh-release@v1 - with: - tag: ${{ github.event.release.tag_name }} - files: ./*.deb - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - release-windows-amd64-agent: - name: release agent windows/amd64 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: wangyoucao577/go-release-action@v1 - id: go_build - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - goos: windows - goarch: amd64 - compress_assets: false - executable_compression: upx - binary_name: pbs-plus-agent - project_path: ./cmd/windows_agent - ldflags: "-H=windowsgui -X 'main.Version=${{ github.event.release.tag_name }}'" - - uses: actions/upload-artifact@v4 - with: - name: windows-binary - path: ${{steps.go_build.outputs.release_asset_dir}}/pbs-plus-agent.exe - - release-windows-amd64-updater: - name: release updater windows/amd64 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: wangyoucao577/go-release-action@v1 - id: go_build_updater - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - goos: windows - goarch: amd64 - compress_assets: false - executable_compression: upx - binary_name: pbs-plus-updater - project_path: ./cmd/windows_updater - ldflags: "-H=windowsgui" - - uses: actions/upload-artifact@v4 - with: - name: windows-updater-binary - path: ${{steps.go_build_updater.outputs.release_asset_dir}}/pbs-plus-updater.exe - - release-windows-amd64-agent-installer: - name: release agent installer windows/amd64 - runs-on: windows-latest - needs: - - release-windows-amd64-agent - - release-windows-amd64-updater - steps: - - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 - with: - name: windows-binary - path: ./build/package/windows/ - - uses: actions/download-artifact@v4 - with: - name: windows-updater-binary - path: ./build/package/windows/ - - id: version - shell: pwsh - run: | - $version = $env:GITHUB_REF -replace 'refs/tags/v', '' - echo "version=$version" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append - - env: - MSI_NAME: "pbs-plus-agent-${{ github.event.release.tag_name }}-windows-installer.msi" - VERSION: ${{ env.version }} - shell: pwsh - run: | - choco install go-msi - Import-Module $env:ChocolateyInstall\helpers\chocolateyProfile.psm1 - refreshenv - $tempDir = Join-Path -Path $Env:GITHUB_WORKSPACE/build/package/windows -ChildPath "temp" - New-Item -Path $tempDir -ItemType Directory -Force - $env:TEMP = $tempDir - $env:TMP = $tempDir - cd ./build/package/windows - go-msi make --msi $env:MSI_NAME --version $env:VERSION - - name: Publish Release Assets - uses: softprops/action-gh-release@v1 - with: - tag: ${{ github.event.release.tag_name }} - files: ./build/package/windows/*.msi - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - +on: + release: + types: [created] + +permissions: + contents: write + packages: write + +jobs: + release-linux-amd64: + name: release linux/amd64 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: wangyoucao577/go-release-action@v1 + id: go_build + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + goos: linux + goarch: amd64 + compress_assets: false + executable_compression: upx + project_path: ./cmd/pbs_plus + ldflags: "-X 'main.Version=${{ github.event.release.tag_name }}'" + - name: pre-packaging script + env: + BINARY_PATH: ${{steps.go_build.outputs.release_asset_dir}} + run: ./build/package/pre-packaging.sh + - uses: jiro4989/build-deb-action@v3 + with: + package: ${{ 
github.event.repository.name }} + package_root: build/package/debian + maintainer: Son Roy Almerol + version: ${{ github.ref }} # refs/tags/v*.*.* + arch: 'amd64' + depends: 'proxmox-backup-server (>= 3.2), proxmox-backup-client (>= 3.2.5), rclone, fuse3' + desc: 'PBS Plus is a project focused on extending Proxmox Backup Server (PBS) with advanced features to create a more competitive backup solution' + homepage: 'https://github.com/${{ github.repository }}' + - name: Publish Release Assets + uses: softprops/action-gh-release@v1 + with: + tag: ${{ github.event.release.tag_name }} + files: ./*.deb + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + release-windows-amd64-agent: + name: release agent windows/amd64 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: wangyoucao577/go-release-action@v1 + id: go_build + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + goos: windows + goarch: amd64 + compress_assets: false + executable_compression: upx + binary_name: pbs-plus-agent + project_path: ./cmd/windows_agent + ldflags: "-H=windowsgui -X 'main.Version=${{ github.event.release.tag_name }}'" + - uses: actions/upload-artifact@v4 + with: + name: windows-binary + path: ${{steps.go_build.outputs.release_asset_dir}}/pbs-plus-agent.exe + + release-windows-amd64-updater: + name: release updater windows/amd64 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: wangyoucao577/go-release-action@v1 + id: go_build_updater + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + goos: windows + goarch: amd64 + compress_assets: false + executable_compression: upx + binary_name: pbs-plus-updater + project_path: ./cmd/windows_updater + ldflags: "-H=windowsgui" + - uses: actions/upload-artifact@v4 + with: + name: windows-updater-binary + path: ${{steps.go_build_updater.outputs.release_asset_dir}}/pbs-plus-updater.exe + + release-windows-amd64-agent-installer: + name: release agent installer windows/amd64 + runs-on: windows-latest + needs: + - release-windows-amd64-agent + - release-windows-amd64-updater + steps: + - uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: windows-binary + path: ./build/package/windows/ + - uses: actions/download-artifact@v4 + with: + name: windows-updater-binary + path: ./build/package/windows/ + - id: version + shell: pwsh + run: | + $version = $env:GITHUB_REF -replace 'refs/tags/v', '' + echo "version=$version" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + - env: + MSI_NAME: "pbs-plus-agent-${{ github.event.release.tag_name }}-windows-installer.msi" + VERSION: ${{ env.version }} + shell: pwsh + run: | + choco install go-msi + Import-Module $env:ChocolateyInstall\helpers\chocolateyProfile.psm1 + refreshenv + $tempDir = Join-Path -Path $Env:GITHUB_WORKSPACE/build/package/windows -ChildPath "temp" + New-Item -Path $tempDir -ItemType Directory -Force + $env:TEMP = $tempDir + $env:TMP = $tempDir + cd ./build/package/windows + go-msi make --msi $env:MSI_NAME --version $env:VERSION + - name: Publish Release Assets + uses: softprops/action-gh-release@v1 + with: + tag: ${{ github.event.release.tag_name }} + files: ./build/package/windows/*.msi + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + diff --git a/build/package/windows/register-server.reg b/build/package/windows/register-server.reg index 2ef6cad..1922800 100644 --- a/build/package/windows/register-server.reg +++ b/build/package/windows/register-server.reg @@ -1,5 +1,5 @@ -Windows Registry Editor Version 5.00 - 
-[HKEY_LOCAL_MACHINE\SOFTWARE\PBSPlus\Config] -"ServerURL"="" -"BootstrapToken"="" +Windows Registry Editor Version 5.00 + +[HKEY_LOCAL_MACHINE\SOFTWARE\PBSPlus\Config] +"ServerURL"="" +"BootstrapToken"="" diff --git a/cmd/pbs_plus/main.go b/cmd/pbs_plus/main.go index 71ce84e..0f19a8b 100644 --- a/cmd/pbs_plus/main.go +++ b/cmd/pbs_plus/main.go @@ -1,264 +1,264 @@ -//go:build linux - -package main - -import ( - "context" - "flag" - "log" - "net/http" - "os" - "os/exec" - "path/filepath" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" - "github.com/sonroyaalmerol/pbs-plus/internal/auth/server" - "github.com/sonroyaalmerol/pbs-plus/internal/auth/token" - "github.com/sonroyaalmerol/pbs-plus/internal/backend/backup" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/agents" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/exclusions" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/jobs" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/plus" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/targets" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/tokens" - mw "github.com/sonroyaalmerol/pbs-plus/internal/proxy/middlewares" - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/websockets" -) - -var Version = "v0.0.0" - -func main() { - err := syslog.InitializeLogger() - if err != nil { - log.Fatalf("Failed to initialize logger: %s", err) - } - proxmox.InitializeProxmox() - - jobRun := flag.String("job", "", "Job ID to execute") - flag.Parse() - - var wsHub *websockets.Server - wsHub = nil - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if *jobRun == "" { - wsHub = websockets.NewServer(ctx) - } - - storeInstance, err := store.Initialize(wsHub, nil) - if err != nil { - syslog.L.Errorf("Failed to initialize store: %v", err) - return - } - - apiToken, err := proxmox.GetAPITokenFromFile() - if err != nil { - syslog.L.Error(err) - } - proxmox.Session.APIToken = apiToken - - // Handle single job execution - if *jobRun != "" { - if proxmox.Session.APIToken == nil { - return - } - - jobTask, err := storeInstance.Database.GetJob(*jobRun) - if err != nil { - syslog.L.Error(err) - return - } - - if jobTask.LastRunState == nil && jobTask.LastRunUpid != "" { - syslog.L.Info("A job is still running, skipping this schedule.") - return - } - - op, err := backup.RunBackup(jobTask, storeInstance, true) - if err != nil { - syslog.L.Error(err) - return - } - - op.Wait() - return - } - - pbsJsLocation := "/usr/share/javascript/proxmox-backup/js/proxmox-backup-gui.js" - err = proxy.MountCompiledJS(pbsJsLocation) - if err != nil { - syslog.L.Errorf("Modified JS mounting failed: %v", err) - return - } - - proxmoxLibLocation := "/usr/share/javascript/proxmox-widget-toolkit/proxmoxlib.js" - err = proxy.MountModdedProxmoxLib(proxmoxLibLocation) - if err != nil { - syslog.L.Errorf("Modified JS mounting failed: %v", err) - return - } - - defer func() { - _ = proxy.UnmountModdedFile(pbsJsLocation) - _ = proxy.UnmountModdedFile(proxmoxLibLocation) - }() - - certOpts := certificates.DefaultOptions() - generator, err := certificates.NewGenerator(certOpts) - if err != nil { - syslog.L.Errorf("Initializing certificate generator failed: %v", err) - return - } - - 
csrfKey, err := os.ReadFile("/etc/proxmox-backup/csrf.key") - if err != nil { - syslog.L.Errorf("CSRF key not found: %v", err) - return - } - - serverConfig := server.DefaultConfig() - serverConfig.CertFile = filepath.Join(certOpts.OutputDir, "server.crt") - serverConfig.KeyFile = filepath.Join(certOpts.OutputDir, "server.key") - serverConfig.CAFile = filepath.Join(certOpts.OutputDir, "ca.crt") - serverConfig.CAKey = filepath.Join(certOpts.OutputDir, "ca.key") - serverConfig.TokenSecret = string(csrfKey) - - if err := generator.ValidateExistingCerts(); err != nil { - if err := generator.GenerateCA(); err != nil { - syslog.L.Errorf("Generating certificates failed: %v", err) - return - } - - if err := generator.GenerateCert("server"); err != nil { - syslog.L.Errorf("Generating certificates failed: %v", err) - return - } - } - - if err := serverConfig.Validate(); err != nil { - syslog.L.Errorf("Validating server config failed: %v", err) - return - } - - storeInstance.CertGenerator = generator - - err = os.Chown(serverConfig.KeyFile, 0, 34) - if err != nil { - syslog.L.Errorf("Changing permissions of key failed: %v", err) - return - } - - err = os.Chown(serverConfig.CertFile, 0, 34) - if err != nil { - syslog.L.Errorf("Changing permissions of cert failed: %v", err) - return - } - - err = serverConfig.Mount() - if err != nil { - syslog.L.Errorf("Mounting certificates failed: %v", err) - return - } - defer func() { - _ = serverConfig.Unmount() - }() - - proxy := exec.Command("/usr/bin/systemctl", "restart", "proxmox-backup-proxy") - proxy.Env = os.Environ() - _ = proxy.Run() - - // Initialize token manager - tokenManager, err := token.NewManager(token.Config{ - TokenExpiration: serverConfig.TokenExpiration, - SecretKey: serverConfig.TokenSecret, - }) - if err != nil { - syslog.L.Errorf("Initializing token manager failed: %v", err) - return - } - storeInstance.Database.TokenManager = tokenManager - - // Setup HTTP server - tlsConfig, err := serverConfig.LoadTLSConfig() - if err != nil { - return - } - - caRenewalCtx, cancelRenewal := context.WithCancel(context.Background()) - defer cancelRenewal() - go func() { - for { - select { - case <-caRenewalCtx.Done(): - return - case <-time.After(time.Hour): - if err := generator.ValidateExistingCerts(); err != nil { - if err := generator.GenerateCA(); err != nil { - syslog.L.Errorf("Generating certificates failed: %v", err) - } - - if err := generator.GenerateCert("server"); err != nil { - syslog.L.Errorf("Generating certificates failed: %v", err) - } - } - - } - } - }() - - // Initialize router with Go 1.22's new pattern syntax - mux := http.NewServeMux() - - // API routes - mux.HandleFunc("/plus/token", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, plus.TokenHandler(storeInstance)))) - mux.HandleFunc("/api2/json/plus/version", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, plus.VersionHandler(storeInstance, Version)))) - mux.HandleFunc("/api2/json/plus/binary", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, plus.DownloadBinary(storeInstance, Version)))) - mux.HandleFunc("/api2/json/plus/binary/checksum", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, plus.DownloadChecksum(storeInstance, Version)))) - mux.HandleFunc("/api2/json/d2d/backup", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.D2DJobHandler(storeInstance)))) - mux.HandleFunc("/api2/json/d2d/target", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, targets.D2DTargetHandler(storeInstance)))) - mux.HandleFunc("/api2/json/d2d/target/agent", 
mw.AgentOnly(storeInstance, mw.CORS(storeInstance, targets.D2DTargetAgentHandler(storeInstance)))) - mux.HandleFunc("/api2/json/d2d/token", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, tokens.D2DTokenHandler(storeInstance)))) - mux.HandleFunc("/api2/json/d2d/exclusion", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, exclusions.D2DExclusionHandler(storeInstance)))) - mux.HandleFunc("/api2/json/d2d/agent-log", mw.AgentOnly(storeInstance, mw.CORS(storeInstance, agents.AgentLogHandler(storeInstance)))) - - // ExtJS routes with path parameters - mux.HandleFunc("/api2/extjs/d2d/backup/{job}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.ExtJsJobRunHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/d2d-target", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, targets.ExtJsTargetHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/d2d-target/{target}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, targets.ExtJsTargetSingleHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/d2d-token", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, tokens.ExtJsTokenHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/d2d-token/{token}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, tokens.ExtJsTokenSingleHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/d2d-exclusion", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, exclusions.ExtJsExclusionHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/d2d-exclusion/{exclusion}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, exclusions.ExtJsExclusionSingleHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/disk-backup-job", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.ExtJsJobHandler(storeInstance)))) - mux.HandleFunc("/api2/extjs/config/disk-backup-job/{job}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.ExtJsJobSingleHandler(storeInstance)))) - - // WebSocket-related routes - mux.HandleFunc("/plus/ws", mw.AgentOnly(storeInstance, plus.WSHandler(storeInstance))) - mux.HandleFunc("/plus/mount/{target}/{drive}", mw.ServerOnly(storeInstance, plus.MountHandler(storeInstance))) - - // Agent auth routes - mux.HandleFunc("/plus/agent/bootstrap", mw.CORS(storeInstance, agents.AgentBootstrapHandler(storeInstance))) - mux.HandleFunc("/plus/agent/renew", mw.AgentOnly(storeInstance, mw.CORS(storeInstance, agents.AgentRenewHandler(storeInstance)))) - - server := &http.Server{ - Addr: serverConfig.Address, - Handler: mux, - TLSConfig: tlsConfig, - ReadTimeout: serverConfig.ReadTimeout, - WriteTimeout: serverConfig.WriteTimeout, - MaxHeaderBytes: serverConfig.MaxHeaderBytes, - } - - syslog.L.Info("Starting proxy server on :8008") - if err := server.ListenAndServeTLS(serverConfig.CertFile, serverConfig.KeyFile); err != nil { - if syslog.L != nil { - syslog.L.Errorf("Server failed: %v", err) - } - } -} +//go:build linux + +package main + +import ( + "context" + "flag" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" + "github.com/sonroyaalmerol/pbs-plus/internal/auth/server" + "github.com/sonroyaalmerol/pbs-plus/internal/auth/token" + "github.com/sonroyaalmerol/pbs-plus/internal/backend/backup" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/agents" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/exclusions" + 
"github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/jobs" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/plus" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/targets" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers/tokens" + mw "github.com/sonroyaalmerol/pbs-plus/internal/proxy/middlewares" + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "github.com/sonroyaalmerol/pbs-plus/internal/websockets" +) + +var Version = "v0.0.0" + +func main() { + err := syslog.InitializeLogger() + if err != nil { + log.Fatalf("Failed to initialize logger: %s", err) + } + proxmox.InitializeProxmox() + + jobRun := flag.String("job", "", "Job ID to execute") + flag.Parse() + + var wsHub *websockets.Server + wsHub = nil + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if *jobRun == "" { + wsHub = websockets.NewServer(ctx) + } + + storeInstance, err := store.Initialize(wsHub, nil) + if err != nil { + syslog.L.Errorf("Failed to initialize store: %v", err) + return + } + + apiToken, err := proxmox.GetAPITokenFromFile() + if err != nil { + syslog.L.Error(err) + } + proxmox.Session.APIToken = apiToken + + // Handle single job execution + if *jobRun != "" { + if proxmox.Session.APIToken == nil { + return + } + + jobTask, err := storeInstance.Database.GetJob(*jobRun) + if err != nil { + syslog.L.Error(err) + return + } + + if jobTask.LastRunState == nil && jobTask.LastRunUpid != "" { + syslog.L.Info("A job is still running, skipping this schedule.") + return + } + + op, err := backup.RunBackup(jobTask, storeInstance, true) + if err != nil { + syslog.L.Error(err) + return + } + + op.Wait() + return + } + + pbsJsLocation := "/usr/share/javascript/proxmox-backup/js/proxmox-backup-gui.js" + err = proxy.MountCompiledJS(pbsJsLocation) + if err != nil { + syslog.L.Errorf("Modified JS mounting failed: %v", err) + return + } + + proxmoxLibLocation := "/usr/share/javascript/proxmox-widget-toolkit/proxmoxlib.js" + err = proxy.MountModdedProxmoxLib(proxmoxLibLocation) + if err != nil { + syslog.L.Errorf("Modified JS mounting failed: %v", err) + return + } + + defer func() { + _ = proxy.UnmountModdedFile(pbsJsLocation) + _ = proxy.UnmountModdedFile(proxmoxLibLocation) + }() + + certOpts := certificates.DefaultOptions() + generator, err := certificates.NewGenerator(certOpts) + if err != nil { + syslog.L.Errorf("Initializing certificate generator failed: %v", err) + return + } + + csrfKey, err := os.ReadFile("/etc/proxmox-backup/csrf.key") + if err != nil { + syslog.L.Errorf("CSRF key not found: %v", err) + return + } + + serverConfig := server.DefaultConfig() + serverConfig.CertFile = filepath.Join(certOpts.OutputDir, "server.crt") + serverConfig.KeyFile = filepath.Join(certOpts.OutputDir, "server.key") + serverConfig.CAFile = filepath.Join(certOpts.OutputDir, "ca.crt") + serverConfig.CAKey = filepath.Join(certOpts.OutputDir, "ca.key") + serverConfig.TokenSecret = string(csrfKey) + + if err := generator.ValidateExistingCerts(); err != nil { + if err := generator.GenerateCA(); err != nil { + syslog.L.Errorf("Generating certificates failed: %v", err) + return + } + + if err := generator.GenerateCert("server"); err != nil { + syslog.L.Errorf("Generating certificates failed: %v", err) + return + } + } + + if err := serverConfig.Validate(); err != nil { + syslog.L.Errorf("Validating server config failed: %v", err) + return + } 
+ + storeInstance.CertGenerator = generator + + err = os.Chown(serverConfig.KeyFile, 0, 34) + if err != nil { + syslog.L.Errorf("Changing permissions of key failed: %v", err) + return + } + + err = os.Chown(serverConfig.CertFile, 0, 34) + if err != nil { + syslog.L.Errorf("Changing permissions of cert failed: %v", err) + return + } + + err = serverConfig.Mount() + if err != nil { + syslog.L.Errorf("Mounting certificates failed: %v", err) + return + } + defer func() { + _ = serverConfig.Unmount() + }() + + proxy := exec.Command("/usr/bin/systemctl", "restart", "proxmox-backup-proxy") + proxy.Env = os.Environ() + _ = proxy.Run() + + // Initialize token manager + tokenManager, err := token.NewManager(token.Config{ + TokenExpiration: serverConfig.TokenExpiration, + SecretKey: serverConfig.TokenSecret, + }) + if err != nil { + syslog.L.Errorf("Initializing token manager failed: %v", err) + return + } + storeInstance.Database.TokenManager = tokenManager + + // Setup HTTP server + tlsConfig, err := serverConfig.LoadTLSConfig() + if err != nil { + return + } + + caRenewalCtx, cancelRenewal := context.WithCancel(context.Background()) + defer cancelRenewal() + go func() { + for { + select { + case <-caRenewalCtx.Done(): + return + case <-time.After(time.Hour): + if err := generator.ValidateExistingCerts(); err != nil { + if err := generator.GenerateCA(); err != nil { + syslog.L.Errorf("Generating certificates failed: %v", err) + } + + if err := generator.GenerateCert("server"); err != nil { + syslog.L.Errorf("Generating certificates failed: %v", err) + } + } + + } + } + }() + + // Initialize router with Go 1.22's new pattern syntax + mux := http.NewServeMux() + + // API routes + mux.HandleFunc("/plus/token", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, plus.TokenHandler(storeInstance)))) + mux.HandleFunc("/api2/json/plus/version", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, plus.VersionHandler(storeInstance, Version)))) + mux.HandleFunc("/api2/json/plus/binary", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, plus.DownloadBinary(storeInstance, Version)))) + mux.HandleFunc("/api2/json/plus/binary/checksum", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, plus.DownloadChecksum(storeInstance, Version)))) + mux.HandleFunc("/api2/json/d2d/backup", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.D2DJobHandler(storeInstance)))) + mux.HandleFunc("/api2/json/d2d/target", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, targets.D2DTargetHandler(storeInstance)))) + mux.HandleFunc("/api2/json/d2d/target/agent", mw.AgentOnly(storeInstance, mw.CORS(storeInstance, targets.D2DTargetAgentHandler(storeInstance)))) + mux.HandleFunc("/api2/json/d2d/token", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, tokens.D2DTokenHandler(storeInstance)))) + mux.HandleFunc("/api2/json/d2d/exclusion", mw.AgentOrServer(storeInstance, mw.CORS(storeInstance, exclusions.D2DExclusionHandler(storeInstance)))) + mux.HandleFunc("/api2/json/d2d/agent-log", mw.AgentOnly(storeInstance, mw.CORS(storeInstance, agents.AgentLogHandler(storeInstance)))) + + // ExtJS routes with path parameters + mux.HandleFunc("/api2/extjs/d2d/backup/{job}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.ExtJsJobRunHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/d2d-target", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, targets.ExtJsTargetHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/d2d-target/{target}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, 
targets.ExtJsTargetSingleHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/d2d-token", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, tokens.ExtJsTokenHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/d2d-token/{token}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, tokens.ExtJsTokenSingleHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/d2d-exclusion", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, exclusions.ExtJsExclusionHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/d2d-exclusion/{exclusion}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, exclusions.ExtJsExclusionSingleHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/disk-backup-job", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.ExtJsJobHandler(storeInstance)))) + mux.HandleFunc("/api2/extjs/config/disk-backup-job/{job}", mw.ServerOnly(storeInstance, mw.CORS(storeInstance, jobs.ExtJsJobSingleHandler(storeInstance)))) + + // WebSocket-related routes + mux.HandleFunc("/plus/ws", mw.AgentOnly(storeInstance, plus.WSHandler(storeInstance))) + mux.HandleFunc("/plus/mount/{target}/{drive}", mw.ServerOnly(storeInstance, plus.MountHandler(storeInstance))) + + // Agent auth routes + mux.HandleFunc("/plus/agent/bootstrap", mw.CORS(storeInstance, agents.AgentBootstrapHandler(storeInstance))) + mux.HandleFunc("/plus/agent/renew", mw.AgentOnly(storeInstance, mw.CORS(storeInstance, agents.AgentRenewHandler(storeInstance)))) + + server := &http.Server{ + Addr: serverConfig.Address, + Handler: mux, + TLSConfig: tlsConfig, + ReadTimeout: serverConfig.ReadTimeout, + WriteTimeout: serverConfig.WriteTimeout, + MaxHeaderBytes: serverConfig.MaxHeaderBytes, + } + + syslog.L.Info("Starting proxy server on :8008") + if err := server.ListenAndServeTLS(serverConfig.CertFile, serverConfig.KeyFile); err != nil { + if syslog.L != nil { + syslog.L.Errorf("Server failed: %v", err) + } + } +} diff --git a/cmd/windows_agent/main.go b/cmd/windows_agent/main.go index 391edae..5c184c1 100644 --- a/cmd/windows_agent/main.go +++ b/cmd/windows_agent/main.go @@ -1,228 +1,228 @@ -//go:build windows - -package main - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" - "runtime/debug" - "sync" - "syscall" - "time" - - "github.com/kardianos/service" - "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -var Version = "v0.0.0" -var ( - mutex sync.Mutex - handle windows.Handle -) - -// watchdogService wraps the original service and adds resilience -type watchdogService struct { - *agentService - restartCount int - lastRestartTime time.Time - maxRestarts int - restartWindow time.Duration -} - -func newWatchdogService(original *agentService) *watchdogService { - return &watchdogService{ - agentService: original, - maxRestarts: 5, // Max restarts in window - restartWindow: time.Hour * 1, // Reset counter after 1 hour - } -} - -func (w *watchdogService) resetRestartCounter() { - if time.Since(w.lastRestartTime) > w.restartWindow { - w.restartCount = 0 - } -} - -func (w *watchdogService) shouldRestart() bool { - w.resetRestartCounter() - return w.restartCount < w.maxRestarts -} - -func (w *watchdogService) Start(s service.Service) error { - go func() { - for { - err := w.runWithRecovery(s) - if err != nil { - syslog.L.Errorf("Service failed with error: %v - Attempting restart", err) - - w.restartCount++ - w.lastRestartTime = time.Now() - - if 
!w.shouldRestart() { - syslog.L.Errorf("Too many restart attempts (%d) within window. Waiting for window reset.", w.restartCount) - time.Sleep(w.restartWindow) - w.restartCount = 0 - } - - time.Sleep(time.Second * 5) // Brief delay before restart - continue - } - break // Clean exit - } - }() - return nil -} - -func (w *watchdogService) runWithRecovery(s service.Service) (err error) { - defer func() { - if r := recover(); r != nil { - stack := string(debug.Stack()) - err = fmt.Errorf("service panicked: %v\nStack:\n%s", r, stack) - syslog.L.Error(err) - } - }() - - return w.agentService.Start(s) -} - -func (w *watchdogService) Stop(s service.Service) error { - return w.agentService.Stop(s) -} - -func main() { - constants.Version = Version - - svcConfig := &service.Config{ - Name: "PBSPlusAgent", - DisplayName: "PBS Plus Agent", - Description: "Agent for orchestrating backups with PBS Plus", - UserName: "", - } - - prg := &agentService{} - watchdog := newWatchdogService(prg) - - s, err := service.New(watchdog, svcConfig) - if err != nil { - fmt.Printf("Failed to initialize service: %v\n", err) - return - } - prg.svc = s - - err = syslog.InitializeLogger(s) - if err != nil { - fmt.Printf("Failed to initialize logger: %v\n", err) - return - } - - if err := createMutex(); err != nil { - syslog.L.Errorf("Error: %v", err) - os.Exit(1) - } - defer releaseMutex() - - err = prg.writeVersionToFile() - if err != nil { - fmt.Printf("Error writing version to file: %v\n", err) - return - } - - // Handle special commands (install, uninstall, etc.) - if len(os.Args) > 1 { - if err := handleServiceCommands(s, os.Args[1]); err != nil { - syslog.L.Errorf("Command handling failed: %v", err) - return - } - } - - // Run the service - err = s.Run() - if err != nil { - syslog.L.Errorf("Service run failed: %v", err) - // Instead of exiting, restart the service - if err := restartService(); err != nil { - syslog.L.Errorf("Service restart failed: %v", err) - } - } -} - -func restartService() error { - cmd := exec.Command("sc", "start", "PBSPlusAgent") - return cmd.Run() -} - -func handleServiceCommands(s service.Service, cmd string) error { - switch cmd { - case "version": - fmt.Print(Version) - os.Stdout.Sync() - os.Exit(0) - case "install", "uninstall": - // Clean up registry before install/uninstall - _ = registry.DeleteKey(registry.LOCAL_MACHINE, `Software\PBSPlus\Auth`) - err := service.Control(s, cmd) - if err != nil { - return fmt.Errorf("failed to %s service: %v", cmd, err) - } - if cmd == "install" { - go func() { - <-time.After(10 * time.Second) - _ = s.Start() - }() - } - // case "--set-server-url": - // if !isAdmin() { - // return fmt.Errorf("needs to be running as administrator") - // } - // if len(os.Args) > 2 { - // serverUrl := os.Args[2] - // if err := setServerURLAdmin(serverUrl); err != nil { - // return fmt.Errorf("error setting server URL: %v", err) - // } - // } - default: - err := service.Control(s, cmd) - if err != nil { - return fmt.Errorf("failed to execute command %s: %v", cmd, err) - } - } - return nil -} - -func createMutex() error { - mutex.Lock() - defer mutex.Unlock() - - // Create a unique mutex name based on the executable path - execPath, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to get executable path: %v", err) - } - mutexName := filepath.Base(execPath) - - // Try to create/acquire the named mutex - h, err := windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(mutexName)) - if err != nil { - return fmt.Errorf("failed to create mutex: %v", err) - } - 
- // Check if the mutex already exists - if windows.GetLastError() == syscall.ERROR_ALREADY_EXISTS { - windows.CloseHandle(h) - return fmt.Errorf("another instance is already running") - } - - handle = h - return nil -} - -func releaseMutex() { - if handle != 0 { - windows.CloseHandle(handle) - } -} +//go:build windows + +package main + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime/debug" + "sync" + "syscall" + "time" + + "github.com/kardianos/service" + "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +var Version = "v0.0.0" +var ( + mutex sync.Mutex + handle windows.Handle +) + +// watchdogService wraps the original service and adds resilience +type watchdogService struct { + *agentService + restartCount int + lastRestartTime time.Time + maxRestarts int + restartWindow time.Duration +} + +func newWatchdogService(original *agentService) *watchdogService { + return &watchdogService{ + agentService: original, + maxRestarts: 5, // Max restarts in window + restartWindow: time.Hour * 1, // Reset counter after 1 hour + } +} + +func (w *watchdogService) resetRestartCounter() { + if time.Since(w.lastRestartTime) > w.restartWindow { + w.restartCount = 0 + } +} + +func (w *watchdogService) shouldRestart() bool { + w.resetRestartCounter() + return w.restartCount < w.maxRestarts +} + +func (w *watchdogService) Start(s service.Service) error { + go func() { + for { + err := w.runWithRecovery(s) + if err != nil { + syslog.L.Errorf("Service failed with error: %v - Attempting restart", err) + + w.restartCount++ + w.lastRestartTime = time.Now() + + if !w.shouldRestart() { + syslog.L.Errorf("Too many restart attempts (%d) within window. Waiting for window reset.", w.restartCount) + time.Sleep(w.restartWindow) + w.restartCount = 0 + } + + time.Sleep(time.Second * 5) // Brief delay before restart + continue + } + break // Clean exit + } + }() + return nil +} + +func (w *watchdogService) runWithRecovery(s service.Service) (err error) { + defer func() { + if r := recover(); r != nil { + stack := string(debug.Stack()) + err = fmt.Errorf("service panicked: %v\nStack:\n%s", r, stack) + syslog.L.Error(err) + } + }() + + return w.agentService.Start(s) +} + +func (w *watchdogService) Stop(s service.Service) error { + return w.agentService.Stop(s) +} + +func main() { + constants.Version = Version + + svcConfig := &service.Config{ + Name: "PBSPlusAgent", + DisplayName: "PBS Plus Agent", + Description: "Agent for orchestrating backups with PBS Plus", + UserName: "", + } + + prg := &agentService{} + watchdog := newWatchdogService(prg) + + s, err := service.New(watchdog, svcConfig) + if err != nil { + fmt.Printf("Failed to initialize service: %v\n", err) + return + } + prg.svc = s + + err = syslog.InitializeLogger(s) + if err != nil { + fmt.Printf("Failed to initialize logger: %v\n", err) + return + } + + if err := createMutex(); err != nil { + syslog.L.Errorf("Error: %v", err) + os.Exit(1) + } + defer releaseMutex() + + err = prg.writeVersionToFile() + if err != nil { + fmt.Printf("Error writing version to file: %v\n", err) + return + } + + // Handle special commands (install, uninstall, etc.) 
+ if len(os.Args) > 1 { + if err := handleServiceCommands(s, os.Args[1]); err != nil { + syslog.L.Errorf("Command handling failed: %v", err) + return + } + } + + // Run the service + err = s.Run() + if err != nil { + syslog.L.Errorf("Service run failed: %v", err) + // Instead of exiting, restart the service + if err := restartService(); err != nil { + syslog.L.Errorf("Service restart failed: %v", err) + } + } +} + +func restartService() error { + cmd := exec.Command("sc", "start", "PBSPlusAgent") + return cmd.Run() +} + +func handleServiceCommands(s service.Service, cmd string) error { + switch cmd { + case "version": + fmt.Print(Version) + os.Stdout.Sync() + os.Exit(0) + case "install", "uninstall": + // Clean up registry before install/uninstall + _ = registry.DeleteKey(registry.LOCAL_MACHINE, `Software\PBSPlus\Auth`) + err := service.Control(s, cmd) + if err != nil { + return fmt.Errorf("failed to %s service: %v", cmd, err) + } + if cmd == "install" { + go func() { + <-time.After(10 * time.Second) + _ = s.Start() + }() + } + // case "--set-server-url": + // if !isAdmin() { + // return fmt.Errorf("needs to be running as administrator") + // } + // if len(os.Args) > 2 { + // serverUrl := os.Args[2] + // if err := setServerURLAdmin(serverUrl); err != nil { + // return fmt.Errorf("error setting server URL: %v", err) + // } + // } + default: + err := service.Control(s, cmd) + if err != nil { + return fmt.Errorf("failed to execute command %s: %v", cmd, err) + } + } + return nil +} + +func createMutex() error { + mutex.Lock() + defer mutex.Unlock() + + // Create a unique mutex name based on the executable path + execPath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %v", err) + } + mutexName := filepath.Base(execPath) + + // Try to create/acquire the named mutex + h, err := windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(mutexName)) + if err != nil { + return fmt.Errorf("failed to create mutex: %v", err) + } + + // Check if the mutex already exists + if windows.GetLastError() == syscall.ERROR_ALREADY_EXISTS { + windows.CloseHandle(h) + return fmt.Errorf("another instance is already running") + } + + handle = h + return nil +} + +func releaseMutex() { + if handle != 0 { + windows.CloseHandle(handle) + } +} diff --git a/cmd/windows_agent/service.go b/cmd/windows_agent/service.go index f4c8668..6b2748c 100644 --- a/cmd/windows_agent/service.go +++ b/cmd/windows_agent/service.go @@ -1,273 +1,273 @@ -//go:build windows -// +build windows - -package main - -import ( - "bytes" - "context" - _ "embed" - "encoding/json" - "fmt" - "io" - "math/rand" - "net/http" - "os" - "path/filepath" - "sync" - "time" - - "github.com/alexflint/go-filemutex" - "github.com/kardianos/service" - "github.com/sonroyaalmerol/pbs-plus/internal/agent" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" - "github.com/sonroyaalmerol/pbs-plus/internal/websockets" -) - -type PingData struct { - Pong bool `json:"pong"` -} - -type PingResp struct { - Data PingData `json:"data"` -} - -type AgentDrivesRequest struct { - Hostname string `json:"hostname"` - Drives []utils.DriveInfo `json:"drives"` -} - -type agentService struct { - svc service.Service - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup -} - -func (p *agentService) Start(s service.Service) error { - p.svc = s - 
p.ctx, p.cancel = context.WithCancel(context.Background()) - - p.wg.Add(2) - go func() { - defer p.wg.Done() - p.run() - }() - go func() { - defer p.wg.Done() - for { - select { - case <-p.ctx.Done(): - return - case <-time.After(time.Hour): - err := agent.CheckAndRenewCertificate() - if err != nil { - syslog.L.Errorf("Certificate renewal manager: %v", err) - } - } - } - }() - - return nil -} - -func (p *agentService) Stop(s service.Service) error { - p.cancel() - p.wg.Wait() - return nil -} - -func (p *agentService) run() { - agent.SetStatus("Starting") - if err := p.waitForServerURL(); err != nil { - syslog.L.Errorf("Failed waiting for server URL: %v", err) - return - } - - if err := p.waitForBootstrap(); err != nil { - syslog.L.Errorf("Failed waiting for bootstrap: %v", err) - return - } - - if err := p.initializeDrives(); err != nil { - syslog.L.Errorf("Failed to initialize drives: %v", err) - return - } - - if err := p.connectWebSocket(); err != nil { - syslog.L.Errorf("WebSocket connection failed: %v", err) - return - } - - go func() { - for { - select { - case <-p.ctx.Done(): - return - case <-time.After(time.Duration(rand.Intn(60*60+1)+4*60*60) * time.Second): - // executes every 4-5 hours - _ = p.initializeDrives() - } - } - }() - - <-p.ctx.Done() -} - -func (p *agentService) waitForServerURL() error { - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - - for { - entry, err := registry.GetEntry(registry.CONFIG, "ServerURL", false) - if err == nil && entry != nil { - return nil - } - - select { - case <-p.ctx.Done(): - return fmt.Errorf("context cancelled while waiting for server URL") - case <-ticker.C: - continue - } - } -} - -func (p *agentService) waitForBootstrap() error { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for { - serverCA, _ := registry.GetEntry(registry.AUTH, "ServerCA", true) - cert, _ := registry.GetEntry(registry.AUTH, "Cert", true) - priv, _ := registry.GetEntry(registry.AUTH, "Priv", true) - - if serverCA != nil && cert != nil && priv != nil { - err := agent.CheckAndRenewCertificate() - if err == nil { - return nil - } - syslog.L.Errorf("Renewal error: %v", err) - } else { - err := agent.Bootstrap() - if err != nil { - syslog.L.Errorf("Bootstrap error: %v", err) - } - } - - select { - case <-p.ctx.Done(): - return fmt.Errorf("context cancelled while waiting for server URL") - case <-ticker.C: - continue - } - } -} - -func (p *agentService) initializeDrives() error { - hostname, err := os.Hostname() - if err != nil { - return fmt.Errorf("failed to get hostname: %w", err) - } - - drives, err := utils.GetLocalDrives() - if err != nil { - return fmt.Errorf("failed to get local drives list: %w", err) - } - - reqBody, err := json.Marshal(&AgentDrivesRequest{ - Hostname: hostname, - Drives: drives, - }) - if err != nil { - return fmt.Errorf("failed to marshal drive request: %w", err) - } - - resp, err := agent.ProxmoxHTTPRequest( - http.MethodPost, - "/api2/json/d2d/target/agent", - bytes.NewBuffer(reqBody), - nil, - ) - if err != nil { - return fmt.Errorf("failed to update agent drives: %w", err) - } - defer resp.Close() - _, _ = io.Copy(io.Discard, resp) - - return nil -} - -func (p *agentService) writeVersionToFile() error { - ex, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to get executable path: %w", err) - } - - versionLockPath := filepath.Join(filepath.Dir(ex), "version.lock") - mutex, err := filemutex.New(versionLockPath) - if err != nil { - return fmt.Errorf("failed to execute mutex: %w", 
err) - } - - mutex.Lock() - defer mutex.Unlock() - - versionFile := filepath.Join(filepath.Dir(ex), "version.txt") - err = os.WriteFile(versionFile, []byte(Version), 0644) - if err != nil { - return fmt.Errorf("failed to write version file: %w", err) - } - - return nil -} - -func (p *agentService) connectWebSocket() error { - for { - config, err := websockets.GetWindowsConfig(Version) - if err != nil { - syslog.L.Errorf("WS client windows config error: %s", err) - return err - } - - tlsConfig, err := agent.GetTLSConfig() - if err != nil { - syslog.L.Errorf("WS client tls config error: %s", err) - return err - } - - config.TLSConfig = tlsConfig - - client, err := websockets.NewWSClient(p.ctx, config) - if err != nil { - syslog.L.Errorf("WS client init error: %s", err) - select { - case <-p.ctx.Done(): - return fmt.Errorf("context cancelled while connecting to WebSocket") - case <-time.After(5 * time.Second): - continue - } - } - - client.RegisterHandler("backup_start", controllers.BackupStartHandler(client)) - client.RegisterHandler("backup_close", controllers.BackupCloseHandler(client)) - - err = client.Connect(p.ctx) - if err != nil { - syslog.L.Errorf("WS client connect error: %s", err) - select { - case <-p.ctx.Done(): - return fmt.Errorf("context cancelled while connecting to WebSocket") - case <-time.After(5 * time.Second): - continue - } - } - - break - } - - return nil -} +//go:build windows +// +build windows + +package main + +import ( + "bytes" + "context" + _ "embed" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "os" + "path/filepath" + "sync" + "time" + + "github.com/alexflint/go-filemutex" + "github.com/kardianos/service" + "github.com/sonroyaalmerol/pbs-plus/internal/agent" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" + "github.com/sonroyaalmerol/pbs-plus/internal/websockets" +) + +type PingData struct { + Pong bool `json:"pong"` +} + +type PingResp struct { + Data PingData `json:"data"` +} + +type AgentDrivesRequest struct { + Hostname string `json:"hostname"` + Drives []utils.DriveInfo `json:"drives"` +} + +type agentService struct { + svc service.Service + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +func (p *agentService) Start(s service.Service) error { + p.svc = s + p.ctx, p.cancel = context.WithCancel(context.Background()) + + p.wg.Add(2) + go func() { + defer p.wg.Done() + p.run() + }() + go func() { + defer p.wg.Done() + for { + select { + case <-p.ctx.Done(): + return + case <-time.After(time.Hour): + err := agent.CheckAndRenewCertificate() + if err != nil { + syslog.L.Errorf("Certificate renewal manager: %v", err) + } + } + } + }() + + return nil +} + +func (p *agentService) Stop(s service.Service) error { + p.cancel() + p.wg.Wait() + return nil +} + +func (p *agentService) run() { + agent.SetStatus("Starting") + if err := p.waitForServerURL(); err != nil { + syslog.L.Errorf("Failed waiting for server URL: %v", err) + return + } + + if err := p.waitForBootstrap(); err != nil { + syslog.L.Errorf("Failed waiting for bootstrap: %v", err) + return + } + + if err := p.initializeDrives(); err != nil { + syslog.L.Errorf("Failed to initialize drives: %v", err) + return + } + + if err := p.connectWebSocket(); err != nil { + syslog.L.Errorf("WebSocket connection failed: %v", err) + return + } + + go func() { + for { + select { + case 
<-p.ctx.Done(): + return + case <-time.After(time.Duration(rand.Intn(60*60+1)+4*60*60) * time.Second): + // executes every 4-5 hours + _ = p.initializeDrives() + } + } + }() + + <-p.ctx.Done() +} + +func (p *agentService) waitForServerURL() error { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + entry, err := registry.GetEntry(registry.CONFIG, "ServerURL", false) + if err == nil && entry != nil { + return nil + } + + select { + case <-p.ctx.Done(): + return fmt.Errorf("context cancelled while waiting for server URL") + case <-ticker.C: + continue + } + } +} + +func (p *agentService) waitForBootstrap() error { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + serverCA, _ := registry.GetEntry(registry.AUTH, "ServerCA", true) + cert, _ := registry.GetEntry(registry.AUTH, "Cert", true) + priv, _ := registry.GetEntry(registry.AUTH, "Priv", true) + + if serverCA != nil && cert != nil && priv != nil { + err := agent.CheckAndRenewCertificate() + if err == nil { + return nil + } + syslog.L.Errorf("Renewal error: %v", err) + } else { + err := agent.Bootstrap() + if err != nil { + syslog.L.Errorf("Bootstrap error: %v", err) + } + } + + select { + case <-p.ctx.Done(): + return fmt.Errorf("context cancelled while waiting for server URL") + case <-ticker.C: + continue + } + } +} + +func (p *agentService) initializeDrives() error { + hostname, err := os.Hostname() + if err != nil { + return fmt.Errorf("failed to get hostname: %w", err) + } + + drives, err := utils.GetLocalDrives() + if err != nil { + return fmt.Errorf("failed to get local drives list: %w", err) + } + + reqBody, err := json.Marshal(&AgentDrivesRequest{ + Hostname: hostname, + Drives: drives, + }) + if err != nil { + return fmt.Errorf("failed to marshal drive request: %w", err) + } + + resp, err := agent.ProxmoxHTTPRequest( + http.MethodPost, + "/api2/json/d2d/target/agent", + bytes.NewBuffer(reqBody), + nil, + ) + if err != nil { + return fmt.Errorf("failed to update agent drives: %w", err) + } + defer resp.Close() + _, _ = io.Copy(io.Discard, resp) + + return nil +} + +func (p *agentService) writeVersionToFile() error { + ex, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + versionLockPath := filepath.Join(filepath.Dir(ex), "version.lock") + mutex, err := filemutex.New(versionLockPath) + if err != nil { + return fmt.Errorf("failed to execute mutex: %w", err) + } + + mutex.Lock() + defer mutex.Unlock() + + versionFile := filepath.Join(filepath.Dir(ex), "version.txt") + err = os.WriteFile(versionFile, []byte(Version), 0644) + if err != nil { + return fmt.Errorf("failed to write version file: %w", err) + } + + return nil +} + +func (p *agentService) connectWebSocket() error { + for { + config, err := websockets.GetWindowsConfig(Version) + if err != nil { + syslog.L.Errorf("WS client windows config error: %s", err) + return err + } + + tlsConfig, err := agent.GetTLSConfig() + if err != nil { + syslog.L.Errorf("WS client tls config error: %s", err) + return err + } + + config.TLSConfig = tlsConfig + + client, err := websockets.NewWSClient(p.ctx, config) + if err != nil { + syslog.L.Errorf("WS client init error: %s", err) + select { + case <-p.ctx.Done(): + return fmt.Errorf("context cancelled while connecting to WebSocket") + case <-time.After(5 * time.Second): + continue + } + } + + client.RegisterHandler("backup_start", controllers.BackupStartHandler(client)) + client.RegisterHandler("backup_close", 
controllers.BackupCloseHandler(client)) + + err = client.Connect(p.ctx) + if err != nil { + syslog.L.Errorf("WS client connect error: %s", err) + select { + case <-p.ctx.Done(): + return fmt.Errorf("context cancelled while connecting to WebSocket") + case <-time.After(5 * time.Second): + continue + } + } + + break + } + + return nil +} diff --git a/cmd/windows_updater/main.go b/cmd/windows_updater/main.go index 13e435f..ecef124 100644 --- a/cmd/windows_updater/main.go +++ b/cmd/windows_updater/main.go @@ -1,247 +1,247 @@ -//go:build windows - -package main - -import ( - "context" - "fmt" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - "syscall" - "time" - - "github.com/alexflint/go-filemutex" - "github.com/kardianos/service" - "github.com/sonroyaalmerol/pbs-plus/internal/agent" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "golang.org/x/sys/windows" -) - -type UpdaterService struct { - svc service.Service - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup -} - -const ( - updateCheckInterval = 2 * time.Minute -) - -var ( - mutex sync.Mutex - handle windows.Handle -) - -func (u *UpdaterService) Start(s service.Service) error { - u.svc = s - u.ctx, u.cancel = context.WithCancel(context.Background()) - - u.wg.Add(1) - go func() { - defer u.wg.Done() - u.runUpdateCheck() - }() - - return nil -} - -func (u *UpdaterService) Stop(s service.Service) error { - u.cancel() - u.wg.Wait() - return nil -} - -func (u *UpdaterService) runUpdateCheck() { - ticker := time.NewTicker(updateCheckInterval) - defer ticker.Stop() - - checkAndUpdate := func() { - hasActiveBackups, err := u.checkForActiveBackups() - if err != nil { - syslog.L.Errorf("Failed to check backup status: %v", err) - return - } - if hasActiveBackups { - syslog.L.Info("Skipping version check - backup in progress") - return - } - - newVersion, err := u.checkForNewVersion() - if err != nil { - syslog.L.Errorf("Version check failed: %v", err) - return - } - - if newVersion != "" { - mainVersion, err := u.getMainServiceVersion() - if err != nil { - syslog.L.Errorf("Failed to get main version: %v", err) - return - } - syslog.L.Infof("New version %s available, current version: %s", newVersion, mainVersion) - - // Double check before update - hasActiveBackups, _ = u.checkForActiveBackups() - if hasActiveBackups { - syslog.L.Info("Postponing update - backup started during version check") - return - } - - if err := u.performUpdate(); err != nil { - syslog.L.Errorf("Update failed: %v", err) - return - } - - syslog.L.Infof("Successfully updated to version %s", newVersion) - } - } - - // Initial check - checkAndUpdate() - - for { - select { - case <-u.ctx.Done(): - return - case <-ticker.C: - checkAndUpdate() - } - } -} - -func (u *UpdaterService) checkForActiveBackups() (bool, error) { - store := controllers.GetNFSSessionStore() - return store.HasSessions(), nil -} - -func (u *UpdaterService) checkForNewVersion() (string, error) { - var versionResp VersionResp - _, err := agent.ProxmoxHTTPRequest( - http.MethodGet, - "/api2/json/plus/version", - nil, - &versionResp, - ) - if err != nil { - return "", err - } - - mainVersion, err := u.getMainServiceVersion() - if err != nil { - return "", err - } - - if versionResp.Version != mainVersion { - return versionResp.Version, nil - } - return "", nil -} - -func main() { - svcConfig := &service.Config{ - Name: "PBSPlusUpdater", - DisplayName: "PBS Plus Updater Service", - Description: "Handles automatic 
updates for PBS Plus Agent", - } - - updater := &UpdaterService{} - s, err := service.New(updater, svcConfig) - if err != nil { - fmt.Printf("Failed to initialize service: %v\n", err) - return - } - - err = syslog.InitializeLogger(s) - if err != nil { - fmt.Printf("Failed to initialize logger: %v\n", err) - return - } - - if err := createMutex(); err != nil { - syslog.L.Errorf("Error: %v", err) - os.Exit(1) - } - defer releaseMutex() - - if len(os.Args) > 1 { - err = service.Control(s, os.Args[1]) - if err != nil { - fmt.Printf("Failed to execute command %s: %v\n", os.Args[1], err) - return - } - return - } - - err = s.Run() - if err != nil { - syslog.L.Errorf("Service run failed: %v", err) - } -} - -func (p *UpdaterService) readVersionFromFile() (string, error) { - ex, err := os.Executable() - if err != nil { - return "", fmt.Errorf("failed to get executable path: %w", err) - } - - versionLockPath := filepath.Join(filepath.Dir(ex), "version.lock") - mutex, err := filemutex.New(versionLockPath) - if err != nil { - return "", fmt.Errorf("failed to execute mutex: %w", err) - } - - mutex.RLock() - defer mutex.RUnlock() - - versionFile := filepath.Join(filepath.Dir(ex), "version.txt") - data, err := os.ReadFile(versionFile) - if err != nil { - return "", fmt.Errorf("failed to read version file: %w", err) - } - - version := strings.TrimSpace(string(data)) - if version == "" { - syslog.L.Errorf("Version file is empty") - return "", fmt.Errorf("version file is empty") - } - - return version, nil -} - -func createMutex() error { - mutex.Lock() - defer mutex.Unlock() - - // Create a unique mutex name based on the executable path - execPath, err := os.Executable() - if err != nil { - return fmt.Errorf("failed to get executable path: %v", err) - } - mutexName := filepath.Base(execPath) - - // Try to create/acquire the named mutex - h, err := windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(mutexName)) - if err != nil { - return fmt.Errorf("failed to create mutex: %v", err) - } - - // Check if the mutex already exists - if windows.GetLastError() == syscall.ERROR_ALREADY_EXISTS { - windows.CloseHandle(h) - return fmt.Errorf("another instance is already running") - } - - handle = h - return nil -} - -func releaseMutex() { - if handle != 0 { - windows.CloseHandle(handle) - } -} +//go:build windows + +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/alexflint/go-filemutex" + "github.com/kardianos/service" + "github.com/sonroyaalmerol/pbs-plus/internal/agent" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "golang.org/x/sys/windows" +) + +type UpdaterService struct { + svc service.Service + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +const ( + updateCheckInterval = 2 * time.Minute +) + +var ( + mutex sync.Mutex + handle windows.Handle +) + +func (u *UpdaterService) Start(s service.Service) error { + u.svc = s + u.ctx, u.cancel = context.WithCancel(context.Background()) + + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.runUpdateCheck() + }() + + return nil +} + +func (u *UpdaterService) Stop(s service.Service) error { + u.cancel() + u.wg.Wait() + return nil +} + +func (u *UpdaterService) runUpdateCheck() { + ticker := time.NewTicker(updateCheckInterval) + defer ticker.Stop() + + checkAndUpdate := func() { + hasActiveBackups, err := u.checkForActiveBackups() + if err != nil { + syslog.L.Errorf("Failed 
to check backup status: %v", err) + return + } + if hasActiveBackups { + syslog.L.Info("Skipping version check - backup in progress") + return + } + + newVersion, err := u.checkForNewVersion() + if err != nil { + syslog.L.Errorf("Version check failed: %v", err) + return + } + + if newVersion != "" { + mainVersion, err := u.getMainServiceVersion() + if err != nil { + syslog.L.Errorf("Failed to get main version: %v", err) + return + } + syslog.L.Infof("New version %s available, current version: %s", newVersion, mainVersion) + + // Double check before update + hasActiveBackups, _ = u.checkForActiveBackups() + if hasActiveBackups { + syslog.L.Info("Postponing update - backup started during version check") + return + } + + if err := u.performUpdate(); err != nil { + syslog.L.Errorf("Update failed: %v", err) + return + } + + syslog.L.Infof("Successfully updated to version %s", newVersion) + } + } + + // Initial check + checkAndUpdate() + + for { + select { + case <-u.ctx.Done(): + return + case <-ticker.C: + checkAndUpdate() + } + } +} + +func (u *UpdaterService) checkForActiveBackups() (bool, error) { + store := controllers.GetNFSSessionStore() + return store.HasSessions(), nil +} + +func (u *UpdaterService) checkForNewVersion() (string, error) { + var versionResp VersionResp + _, err := agent.ProxmoxHTTPRequest( + http.MethodGet, + "/api2/json/plus/version", + nil, + &versionResp, + ) + if err != nil { + return "", err + } + + mainVersion, err := u.getMainServiceVersion() + if err != nil { + return "", err + } + + if versionResp.Version != mainVersion { + return versionResp.Version, nil + } + return "", nil +} + +func main() { + svcConfig := &service.Config{ + Name: "PBSPlusUpdater", + DisplayName: "PBS Plus Updater Service", + Description: "Handles automatic updates for PBS Plus Agent", + } + + updater := &UpdaterService{} + s, err := service.New(updater, svcConfig) + if err != nil { + fmt.Printf("Failed to initialize service: %v\n", err) + return + } + + err = syslog.InitializeLogger(s) + if err != nil { + fmt.Printf("Failed to initialize logger: %v\n", err) + return + } + + if err := createMutex(); err != nil { + syslog.L.Errorf("Error: %v", err) + os.Exit(1) + } + defer releaseMutex() + + if len(os.Args) > 1 { + err = service.Control(s, os.Args[1]) + if err != nil { + fmt.Printf("Failed to execute command %s: %v\n", os.Args[1], err) + return + } + return + } + + err = s.Run() + if err != nil { + syslog.L.Errorf("Service run failed: %v", err) + } +} + +func (p *UpdaterService) readVersionFromFile() (string, error) { + ex, err := os.Executable() + if err != nil { + return "", fmt.Errorf("failed to get executable path: %w", err) + } + + versionLockPath := filepath.Join(filepath.Dir(ex), "version.lock") + mutex, err := filemutex.New(versionLockPath) + if err != nil { + return "", fmt.Errorf("failed to execute mutex: %w", err) + } + + mutex.RLock() + defer mutex.RUnlock() + + versionFile := filepath.Join(filepath.Dir(ex), "version.txt") + data, err := os.ReadFile(versionFile) + if err != nil { + return "", fmt.Errorf("failed to read version file: %w", err) + } + + version := strings.TrimSpace(string(data)) + if version == "" { + syslog.L.Errorf("Version file is empty") + return "", fmt.Errorf("version file is empty") + } + + return version, nil +} + +func createMutex() error { + mutex.Lock() + defer mutex.Unlock() + + // Create a unique mutex name based on the executable path + execPath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %v", err) + } + 
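// Only the executable's base name is used below, so every copy of this binary shares one instance guard regardless of install path. + 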
mutexName := filepath.Base(execPath) + + // Try to create/acquire the named mutex + h, err := windows.CreateMutex(nil, false, windows.StringToUTF16Ptr(mutexName)) + if err != nil { + return fmt.Errorf("failed to create mutex: %v", err) + } + + // Check if the mutex already exists + if windows.GetLastError() == syscall.ERROR_ALREADY_EXISTS { + windows.CloseHandle(h) + return fmt.Errorf("another instance is already running") + } + + handle = h + return nil +} + +func releaseMutex() { + if handle != 0 { + windows.CloseHandle(handle) + } +} diff --git a/internal/agent/bootstrap.go b/internal/agent/bootstrap.go index cfdb958..6285fab 100644 --- a/internal/agent/bootstrap.go +++ b/internal/agent/bootstrap.go @@ -1,165 +1,165 @@ -//go:build windows - -package agent - -import ( - "bytes" - "crypto/tls" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "strings" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" - "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -type BootstrapRequest struct { - Hostname string `json:"hostname"` - CSR string `json:"csr"` - Drives []utils.DriveInfo `json:"drives"` -} - -type BootstrapResponse struct { - Cert string `json:"cert"` - CA string `json:"ca"` -} - -func Bootstrap() error { - token, err := registry.GetEntry(registry.CONFIG, "BootstrapToken", false) - if err != nil || token == nil { - return fmt.Errorf("Bootstrap: token not found -> %w", err) - } - - serverUrl, err := registry.GetEntry(registry.CONFIG, "ServerURL", false) - if err != nil || serverUrl == nil { - return fmt.Errorf("Bootstrap: server url not found -> %w", err) - } - - hostname, _ := os.Hostname() - - csr, privKey, err := certificates.GenerateCSR(hostname, 2048) - if err != nil { - return fmt.Errorf("Bootstrap: generating csr failed -> %w", err) - } - - encodedCSR := base64.StdEncoding.EncodeToString(csr) - - drives, err := utils.GetLocalDrives() - if err != nil { - return fmt.Errorf("Bootstrap: failed to get local drives list: %w", err) - } - - reqBody, err := json.Marshal(&BootstrapRequest{ - Hostname: hostname, - Drives: drives, - CSR: encodedCSR, - }) - if err != nil { - return fmt.Errorf("failed to marshal bootstrap request: %w", err) - } - - req, err := http.NewRequest( - http.MethodPost, - fmt.Sprintf( - "%s%s", - strings.TrimSuffix(serverUrl.Value, "/"), - "/plus/agent/bootstrap", - ), - bytes.NewBuffer(reqBody), - ) - - if err != nil { - return fmt.Errorf("Bootstrap: error creating http request -> %w", err) - } - - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(token.Value))) - - if httpClient == nil { - httpClient = &http.Client{ - Timeout: time.Second * 30, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - } - } - - resp, err := httpClient.Do(req) - if err != nil { - return fmt.Errorf("Bootstrap: error executing http request -> %w", err) - } - - defer func() { - _, _ = io.Copy(io.Discard, resp.Body) - resp.Body.Close() - }() - - rawBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("Bootstrap: error getting body content -> %w", err) - } - - bootstrapResp := &BootstrapResponse{} - err = json.Unmarshal(rawBody, bootstrapResp) - if err != nil { - return fmt.Errorf("Bootstrap: error json unmarshal body content (%s) -> %w", string(rawBody), err) - } - - decodedCA, err := 
base64.StdEncoding.DecodeString(bootstrapResp.CA) - if err != nil { - return fmt.Errorf("Bootstrap: error decoding ca content (%s) -> %w", string(bootstrapResp.CA), err) - } - - decodedCert, err := base64.StdEncoding.DecodeString(bootstrapResp.Cert) - if err != nil { - return fmt.Errorf("Bootstrap: error decoding cert content (%s) -> %w", string(bootstrapResp.Cert), err) - } - - privKeyPEM := certificates.EncodeKeyPEM(privKey) - - caEntry := registry.RegistryEntry{ - Key: "ServerCA", - Value: string(decodedCA), - Path: registry.AUTH, - IsSecret: true, - } - - certEntry := registry.RegistryEntry{ - Key: "Cert", - Value: string(decodedCert), - Path: registry.AUTH, - IsSecret: true, - } - - privEntry := registry.RegistryEntry{ - Key: "Priv", - Value: string(privKeyPEM), - Path: registry.AUTH, - IsSecret: true, - } - - err = registry.CreateEntry(&caEntry) - if err != nil { - return fmt.Errorf("Bootstrap: error storing ca to registry -> %w", err) - } - - err = registry.CreateEntry(&certEntry) - if err != nil { - return fmt.Errorf("Bootstrap: error storing cert to registry -> %w", err) - } - - err = registry.CreateEntry(&privEntry) - if err != nil { - return fmt.Errorf("Bootstrap: error storing priv to registry -> %w", err) - } - - return nil -} +//go:build windows + +package agent + +import ( + "bytes" + "crypto/tls" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" + "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +type BootstrapRequest struct { + Hostname string `json:"hostname"` + CSR string `json:"csr"` + Drives []utils.DriveInfo `json:"drives"` +} + +type BootstrapResponse struct { + Cert string `json:"cert"` + CA string `json:"ca"` +} + +func Bootstrap() error { + token, err := registry.GetEntry(registry.CONFIG, "BootstrapToken", false) + if err != nil || token == nil { + return fmt.Errorf("Bootstrap: token not found -> %w", err) + } + + serverUrl, err := registry.GetEntry(registry.CONFIG, "ServerURL", false) + if err != nil || serverUrl == nil { + return fmt.Errorf("Bootstrap: server url not found -> %w", err) + } + + hostname, _ := os.Hostname() + + csr, privKey, err := certificates.GenerateCSR(hostname, 2048) + if err != nil { + return fmt.Errorf("Bootstrap: generating csr failed -> %w", err) + } + + encodedCSR := base64.StdEncoding.EncodeToString(csr) + + drives, err := utils.GetLocalDrives() + if err != nil { + return fmt.Errorf("Bootstrap: failed to get local drives list: %w", err) + } + + reqBody, err := json.Marshal(&BootstrapRequest{ + Hostname: hostname, + Drives: drives, + CSR: encodedCSR, + }) + if err != nil { + return fmt.Errorf("failed to marshal bootstrap request: %w", err) + } + + req, err := http.NewRequest( + http.MethodPost, + fmt.Sprintf( + "%s%s", + strings.TrimSuffix(serverUrl.Value, "/"), + "/plus/agent/bootstrap", + ), + bytes.NewBuffer(reqBody), + ) + + if err != nil { + return fmt.Errorf("Bootstrap: error creating http request -> %w", err) + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(token.Value))) + + if httpClient == nil { + httpClient = &http.Client{ + Timeout: time.Second * 30, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("Bootstrap: error executing http 
request -> %w", err) + } + + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + rawBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("Bootstrap: error getting body content -> %w", err) + } + + bootstrapResp := &BootstrapResponse{} + err = json.Unmarshal(rawBody, bootstrapResp) + if err != nil { + return fmt.Errorf("Bootstrap: error json unmarshal body content (%s) -> %w", string(rawBody), err) + } + + decodedCA, err := base64.StdEncoding.DecodeString(bootstrapResp.CA) + if err != nil { + return fmt.Errorf("Bootstrap: error decoding ca content (%s) -> %w", string(bootstrapResp.CA), err) + } + + decodedCert, err := base64.StdEncoding.DecodeString(bootstrapResp.Cert) + if err != nil { + return fmt.Errorf("Bootstrap: error decoding cert content (%s) -> %w", string(bootstrapResp.Cert), err) + } + + privKeyPEM := certificates.EncodeKeyPEM(privKey) + + caEntry := registry.RegistryEntry{ + Key: "ServerCA", + Value: string(decodedCA), + Path: registry.AUTH, + IsSecret: true, + } + + certEntry := registry.RegistryEntry{ + Key: "Cert", + Value: string(decodedCert), + Path: registry.AUTH, + IsSecret: true, + } + + privEntry := registry.RegistryEntry{ + Key: "Priv", + Value: string(privKeyPEM), + Path: registry.AUTH, + IsSecret: true, + } + + err = registry.CreateEntry(&caEntry) + if err != nil { + return fmt.Errorf("Bootstrap: error storing ca to registry -> %w", err) + } + + err = registry.CreateEntry(&certEntry) + if err != nil { + return fmt.Errorf("Bootstrap: error storing cert to registry -> %w", err) + } + + err = registry.CreateEntry(&privEntry) + if err != nil { + return fmt.Errorf("Bootstrap: error storing priv to registry -> %w", err) + } + + return nil +} diff --git a/internal/agent/controllers/nfssession.go b/internal/agent/controllers/nfssession.go index ba7b9b8..453044b 100644 --- a/internal/agent/controllers/nfssession.go +++ b/internal/agent/controllers/nfssession.go @@ -1,124 +1,124 @@ -//go:build windows - -package controllers - -import ( - "encoding/json" - "os" - "path/filepath" - "sync" - "time" - - "github.com/alexflint/go-filemutex" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" -) - -type NFSSessionStore struct { - mu *filemutex.FileMutex - sessions map[string]*NFSSessionData - filepath string -} - -type NFSSessionData struct { - Drive string `json:"drive"` - StartTime time.Time `json:"start_time"` -} - -var nfsSessions sync.Map - -var ( - store *NFSSessionStore - once sync.Once -) - -func GetNFSSessionStore() *NFSSessionStore { - once.Do(func() { - execPath, err := os.Executable() - if err != nil { - panic(err) - } - storePath := filepath.Join(filepath.Dir(execPath), "nfssessions.json") - storeLockPath := filepath.Join(filepath.Dir(execPath), "nfssessions.lock") - mutex, err := filemutex.New(storeLockPath) - if err != nil { - panic(err) - } - - store = &NFSSessionStore{ - sessions: make(map[string]*NFSSessionData), - filepath: storePath, - mu: mutex, - } - store.load() - }) - return store -} - -func (s *NFSSessionStore) HasSessions() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.sessions) > 0 -} - -func (s *NFSSessionStore) load() { - s.mu.Lock() - defer s.mu.Unlock() - - data, err := os.ReadFile(s.filepath) - if err != nil { - if !os.IsNotExist(err) { - syslog.L.Errorf("Error reading session store: %v", err) - } - return - } - - if err := json.Unmarshal(data, &s.sessions); err != nil { - syslog.L.Errorf("Error unmarshaling session store: 
%v", err) - } -} - -func (s *NFSSessionStore) save() error { - s.mu.RLock() - defer s.mu.RUnlock() - - data, err := json.MarshalIndent(s.sessions, "", " ") - if err != nil { - return err - } - - return os.WriteFile(s.filepath, data, 0644) -} - -func (s *NFSSessionStore) Store(drive string, session *nfs.NFSSession) error { - s.mu.Lock() - defer s.mu.Unlock() - - sessionData := &NFSSessionData{ - Drive: drive, - StartTime: time.Now(), - } - - s.sessions[drive] = sessionData - nfsSessions.Store(drive, session) - - return s.save() -} - -func (s *NFSSessionStore) Load(drive string) (*NFSSessionData, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - - session, ok := s.sessions[drive] - return session, ok -} - -func (s *NFSSessionStore) Delete(drive string) error { - s.mu.Lock() - defer s.mu.Unlock() - - delete(s.sessions, drive) - nfsSessions.Delete(drive) - - return s.save() -} +//go:build windows + +package controllers + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "time" + + "github.com/alexflint/go-filemutex" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" +) + +type NFSSessionStore struct { + mu *filemutex.FileMutex + sessions map[string]*NFSSessionData + filepath string +} + +type NFSSessionData struct { + Drive string `json:"drive"` + StartTime time.Time `json:"start_time"` +} + +var nfsSessions sync.Map + +var ( + store *NFSSessionStore + once sync.Once +) + +func GetNFSSessionStore() *NFSSessionStore { + once.Do(func() { + execPath, err := os.Executable() + if err != nil { + panic(err) + } + storePath := filepath.Join(filepath.Dir(execPath), "nfssessions.json") + storeLockPath := filepath.Join(filepath.Dir(execPath), "nfssessions.lock") + mutex, err := filemutex.New(storeLockPath) + if err != nil { + panic(err) + } + + store = &NFSSessionStore{ + sessions: make(map[string]*NFSSessionData), + filepath: storePath, + mu: mutex, + } + store.load() + }) + return store +} + +func (s *NFSSessionStore) HasSessions() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.sessions) > 0 +} + +func (s *NFSSessionStore) load() { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := os.ReadFile(s.filepath) + if err != nil { + if !os.IsNotExist(err) { + syslog.L.Errorf("Error reading session store: %v", err) + } + return + } + + if err := json.Unmarshal(data, &s.sessions); err != nil { + syslog.L.Errorf("Error unmarshaling session store: %v", err) + } +} + +func (s *NFSSessionStore) save() error { + s.mu.RLock() + defer s.mu.RUnlock() + + data, err := json.MarshalIndent(s.sessions, "", " ") + if err != nil { + return err + } + + return os.WriteFile(s.filepath, data, 0644) +} + +func (s *NFSSessionStore) Store(drive string, session *nfs.NFSSession) error { + s.mu.Lock() + defer s.mu.Unlock() + + sessionData := &NFSSessionData{ + Drive: drive, + StartTime: time.Now(), + } + + s.sessions[drive] = sessionData + nfsSessions.Store(drive, session) + + return s.save() +} + +func (s *NFSSessionStore) Load(drive string) (*NFSSessionData, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + session, ok := s.sessions[drive] + return session, ok +} + +func (s *NFSSessionStore) Delete(drive string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.sessions, drive) + nfsSessions.Delete(drive) + + return s.save() +} diff --git a/internal/agent/controllers/ws.go b/internal/agent/controllers/ws.go index 88d4c3f..d40eaaf 100644 --- a/internal/agent/controllers/ws.go +++ b/internal/agent/controllers/ws.go @@ -1,90 +1,90 @@ 
-//go:build windows - -package controllers - -import ( - "context" - "fmt" - - "github.com/sonroyaalmerol/pbs-plus/internal/agent" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/websockets" -) - -func sendResponse(c *websockets.WSClient, msgType, content string) { - response := websockets.Message{ - Type: "response-" + msgType, - Content: "Acknowledged: " + content, - } - - c.Send(context.Background(), response) -} - -func BackupStartHandler(c *websockets.WSClient) func(ctx context.Context, msg *websockets.Message) error { - return func(ctx context.Context, msg *websockets.Message) error { - drive := msg.Content - syslog.L.Infof("Received backup request for drive %s.", drive) - - store := GetNFSSessionStore() - if err := store.Delete(drive); err != nil { - syslog.L.Errorf("Error cleaning up session store: %v", err) - } - - backupStatus := agent.GetBackupStatus() - backupStatus.StartBackup(drive) - defer backupStatus.EndBackup(drive) - - snapshot, err := snapshots.Snapshot(drive) - if err != nil { - syslog.L.Errorf("snapshot error: %v", err) - return err - } - - nfsSession := nfs.NewNFSSession(context.Background(), snapshot, drive) - if nfsSession == nil { - syslog.L.Error("NFS session is nil.") - return fmt.Errorf("NFS session is nil.") - } - - if err := store.Store(drive, nfsSession); err != nil { - syslog.L.Errorf("Error storing session: %v", err) - } - - go func() { - defer func() { - if r := recover(); r != nil { - syslog.L.Errorf("Panic in NFS session for drive %s: %v", drive, r) - } - if err := store.Delete(drive); err != nil { - syslog.L.Errorf("Error cleaning up session store: %v", err) - } - backupStatus.EndBackup(drive) - }() - nfsSession.Serve() - }() - - sendResponse(c, "backup_start", drive) - return nil - } -} - -func BackupCloseHandler(c *websockets.WSClient) func(ctx context.Context, msg *websockets.Message) error { - return func(ctx context.Context, msg *websockets.Message) error { - drive := msg.Content - syslog.L.Infof("Received closure request for drive %s.", drive) - - store := GetNFSSessionStore() - if err := store.Delete(drive); err != nil { - syslog.L.Errorf("Error cleaning up session store: %v", err) - return err - } - - backupStatus := agent.GetBackupStatus() - backupStatus.EndBackup(drive) - - sendResponse(c, "backup_close", drive) - return nil - } -} +//go:build windows + +package controllers + +import ( + "context" + "fmt" + + "github.com/sonroyaalmerol/pbs-plus/internal/agent" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "github.com/sonroyaalmerol/pbs-plus/internal/websockets" +) + +func sendResponse(c *websockets.WSClient, msgType, content string) { + response := websockets.Message{ + Type: "response-" + msgType, + Content: "Acknowledged: " + content, + } + + c.Send(context.Background(), response) +} + +func BackupStartHandler(c *websockets.WSClient) func(ctx context.Context, msg *websockets.Message) error { + return func(ctx context.Context, msg *websockets.Message) error { + drive := msg.Content + syslog.L.Infof("Received backup request for drive %s.", drive) + + store := GetNFSSessionStore() + if err := store.Delete(drive); err != nil { + syslog.L.Errorf("Error cleaning up session store: %v", err) + } + + backupStatus := agent.GetBackupStatus() + 
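// Note: the defer below ends the backup as soon as this handler returns, not when serving stops; the goroutine further down calls EndBackup again when the NFS session exits. + 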
backupStatus.StartBackup(drive) + defer backupStatus.EndBackup(drive) + + snapshot, err := snapshots.Snapshot(drive) + if err != nil { + syslog.L.Errorf("snapshot error: %v", err) + return err + } + + nfsSession := nfs.NewNFSSession(context.Background(), snapshot, drive) + if nfsSession == nil { + syslog.L.Error("NFS session is nil.") + return fmt.Errorf("NFS session is nil.") + } + + if err := store.Store(drive, nfsSession); err != nil { + syslog.L.Errorf("Error storing session: %v", err) + } + + go func() { + defer func() { + if r := recover(); r != nil { + syslog.L.Errorf("Panic in NFS session for drive %s: %v", drive, r) + } + if err := store.Delete(drive); err != nil { + syslog.L.Errorf("Error cleaning up session store: %v", err) + } + backupStatus.EndBackup(drive) + }() + nfsSession.Serve() + }() + + sendResponse(c, "backup_start", drive) + return nil + } +} + +func BackupCloseHandler(c *websockets.WSClient) func(ctx context.Context, msg *websockets.Message) error { + return func(ctx context.Context, msg *websockets.Message) error { + drive := msg.Content + syslog.L.Infof("Received closure request for drive %s.", drive) + + store := GetNFSSessionStore() + if err := store.Delete(drive); err != nil { + syslog.L.Errorf("Error cleaning up session store: %v", err) + return err + } + + backupStatus := agent.GetBackupStatus() + backupStatus.EndBackup(drive) + + sendResponse(c, "backup_close", drive) + return nil + } +} diff --git a/internal/agent/nfs/auth_handler.go b/internal/agent/nfs/auth_handler.go index 3ff39a3..ef13793 100644 --- a/internal/agent/nfs/auth_handler.go +++ b/internal/agent/nfs/auth_handler.go @@ -1,101 +1,101 @@ -//go:build windows - -package nfs - -import ( - "context" - "fmt" - "net" - "sync" - "time" - - "github.com/go-git/go-billy/v5" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - nfs "github.com/willscott/go-nfs" - "golang.org/x/sys/windows" -) - -type NFSHandler struct { - mu sync.Mutex - session *NFSSession -} - -// Verify Handler interface implementation -var _ nfs.Handler = (*NFSHandler)(nil) - -// ToHandle converts a filesystem path to an opaque handle -func (h *NFSHandler) ToHandle(fs billy.Filesystem, path []string) []byte { - return nil -} - -// FromHandle converts an opaque handle back to a filesystem and path -func (h *NFSHandler) FromHandle(fh []byte) (billy.Filesystem, []string, error) { - return nil, nil, nil -} - -func (h *NFSHandler) HandleLimit() int { - return -1 -} - -// InvalidateHandle - Required by interface but no-op in read-only FS -func (h *NFSHandler) InvalidateHandle(fs billy.Filesystem, fh []byte) error { - // In read-only FS, handles never become invalid - return nil -} - -func (h *NFSHandler) validateConnection(conn net.Conn) error { - remoteAddr := conn.RemoteAddr().String() - - clientIP, _, _ := net.SplitHostPort(remoteAddr) - serverIPs, _ := net.LookupHost(h.session.serverURL.Hostname()) - for _, ip := range serverIPs { - if clientIP == ip { - return nil - } - } - - return fmt.Errorf("unregistered client attempted to connect: %s", remoteAddr) -} - -func (h *NFSHandler) Mount(ctx context.Context, conn net.Conn, req nfs.MountRequest) (nfs.MountStatus, billy.Filesystem, []nfs.AuthFlavor) { - syslog.L.Infof("[NFS.Mount] Received mount request for path: %s from %s", - string(req.Dirpath), conn.RemoteAddr().String()) - - if err := h.validateConnection(conn); err != nil { - syslog.L.Errorf("[NFS.Mount] Connection validation failed: %v", err) - return nfs.MountStatusErrPerm, nil, nil - } - - syslog.L.Infof("[NFS.Mount] Mount 
successful, serving from: %s", h.session.Snapshot.SnapshotPath) - return nfs.MountStatusOk, h.session.FS, []nfs.AuthFlavor{nfs.AuthFlavorNull} -} - -func (h *NFSHandler) Change(fs billy.Filesystem) billy.Change { - return nil -} - -func (h *NFSHandler) FSStat(ctx context.Context, fs billy.Filesystem, stat *nfs.FSStat) error { - driveLetter := h.session.Snapshot.DriveLetter - drivePath := driveLetter + `:\` - - var totalBytes uint64 - err := windows.GetDiskFreeSpaceEx( - windows.StringToUTF16Ptr(drivePath), - nil, - &totalBytes, - nil, - ) - if err != nil { - return err - } - - stat.TotalSize = totalBytes - stat.FreeSize = 0 - stat.AvailableSize = 0 - stat.TotalFiles = 1 << 20 - stat.FreeFiles = 0 - stat.AvailableFiles = 0 - stat.CacheHint = time.Minute - - return nil -} +//go:build windows + +package nfs + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/go-git/go-billy/v5" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + nfs "github.com/willscott/go-nfs" + "golang.org/x/sys/windows" +) + +type NFSHandler struct { + mu sync.Mutex + session *NFSSession +} + +// Verify Handler interface implementation +var _ nfs.Handler = (*NFSHandler)(nil) + +// ToHandle converts a filesystem path to an opaque handle +func (h *NFSHandler) ToHandle(fs billy.Filesystem, path []string) []byte { + return nil +} + +// FromHandle converts an opaque handle back to a filesystem and path +func (h *NFSHandler) FromHandle(fh []byte) (billy.Filesystem, []string, error) { + return nil, nil, nil +} + +func (h *NFSHandler) HandleLimit() int { + return -1 +} + +// InvalidateHandle - Required by interface but no-op in read-only FS +func (h *NFSHandler) InvalidateHandle(fs billy.Filesystem, fh []byte) error { + // In read-only FS, handles never become invalid + return nil +} + +func (h *NFSHandler) validateConnection(conn net.Conn) error { + remoteAddr := conn.RemoteAddr().String() + + clientIP, _, _ := net.SplitHostPort(remoteAddr) + serverIPs, _ := net.LookupHost(h.session.serverURL.Hostname()) + for _, ip := range serverIPs { + if clientIP == ip { + return nil + } + } + + return fmt.Errorf("unregistered client attempted to connect: %s", remoteAddr) +} + +func (h *NFSHandler) Mount(ctx context.Context, conn net.Conn, req nfs.MountRequest) (nfs.MountStatus, billy.Filesystem, []nfs.AuthFlavor) { + syslog.L.Infof("[NFS.Mount] Received mount request for path: %s from %s", + string(req.Dirpath), conn.RemoteAddr().String()) + + if err := h.validateConnection(conn); err != nil { + syslog.L.Errorf("[NFS.Mount] Connection validation failed: %v", err) + return nfs.MountStatusErrPerm, nil, nil + } + + syslog.L.Infof("[NFS.Mount] Mount successful, serving from: %s", h.session.Snapshot.SnapshotPath) + return nfs.MountStatusOk, h.session.FS, []nfs.AuthFlavor{nfs.AuthFlavorNull} +} + +func (h *NFSHandler) Change(fs billy.Filesystem) billy.Change { + return nil +} + +func (h *NFSHandler) FSStat(ctx context.Context, fs billy.Filesystem, stat *nfs.FSStat) error { + driveLetter := h.session.Snapshot.DriveLetter + drivePath := driveLetter + `:\` + + var totalBytes uint64 + err := windows.GetDiskFreeSpaceEx( + windows.StringToUTF16Ptr(drivePath), + nil, + &totalBytes, + nil, + ) + if err != nil { + return err + } + + stat.TotalSize = totalBytes + stat.FreeSize = 0 + stat.AvailableSize = 0 + stat.TotalFiles = 1 << 20 + stat.FreeFiles = 0 + stat.AvailableFiles = 0 + stat.CacheHint = time.Minute + + return nil +} diff --git a/internal/agent/nfs/logging.go b/internal/agent/nfs/logging.go index 83bdd6c..4919355 100644 
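logging.go below adapts go-nfs's Logger interface onto the agent's syslog facility, so NFS server internals land in the same log stream; it only takes effect once installed, e.g. via the call currently commented out in nfs.go's Serve:

nfs.SetLogger(&nfsLogger{})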
--- a/internal/agent/nfs/logging.go +++ b/internal/agent/nfs/logging.go @@ -1,73 +1,73 @@ -package nfs - -import ( - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - nfs "github.com/willscott/go-nfs" -) - -type nfsLogger struct { - nfs.Logger -} - -func (l *nfsLogger) Info(v ...interface{}) { - v = append([]interface{}{"[NFS.Info] "}, v...) - syslog.L.Info(v...) -} - -func (l *nfsLogger) Infof(format string, v ...interface{}) { - syslog.L.Infof("[NFS.Info] "+format, v...) -} - -func (l *nfsLogger) Print(v ...interface{}) { - v = append([]interface{}{"[NFS.Print] "}, v...) - syslog.L.Info(v...) -} - -func (l *nfsLogger) Printf(format string, v ...interface{}) { - syslog.L.Infof("[NFS.Print] "+format, v...) -} - -func (l *nfsLogger) Debug(v ...interface{}) { - v = append([]interface{}{"[NFS.Debug] "}, v...) - syslog.L.Info(v...) -} - -func (l *nfsLogger) Debugf(format string, v ...interface{}) { - syslog.L.Infof("[NFS.Debug] "+format, v...) -} - -func (l *nfsLogger) Error(v ...interface{}) { - v = append([]interface{}{"[NFS.Error] "}, v...) - syslog.L.Error(v...) -} - -func (l *nfsLogger) Errorf(format string, v ...interface{}) { - syslog.L.Errorf("[NFS.Error] "+format, v...) -} - -func (l *nfsLogger) Panic(v ...interface{}) { - v = append([]interface{}{"[NFS.Panic] "}, v...) - syslog.L.Error(v...) -} - -func (l *nfsLogger) Panicf(format string, v ...interface{}) { - syslog.L.Errorf("[NFS.Panic] "+format, v...) -} - -func (l *nfsLogger) Trace(v ...interface{}) { - v = append([]interface{}{"[NFS.Trace] "}, v...) - syslog.L.Info(v...) -} - -func (l *nfsLogger) Tracef(format string, v ...interface{}) { - syslog.L.Infof("[NFS.Trace] "+format, v...) -} - -func (l *nfsLogger) Warn(v ...interface{}) { - v = append([]interface{}{"[NFS.Warn] "}, v...) - syslog.L.Warn(v...) -} - -func (l *nfsLogger) Warnf(format string, v ...interface{}) { - syslog.L.Warnf("[NFS.Warn] "+format, v...) -} +package nfs + +import ( + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + nfs "github.com/willscott/go-nfs" +) + +type nfsLogger struct { + nfs.Logger +} + +func (l *nfsLogger) Info(v ...interface{}) { + v = append([]interface{}{"[NFS.Info] "}, v...) + syslog.L.Info(v...) +} + +func (l *nfsLogger) Infof(format string, v ...interface{}) { + syslog.L.Infof("[NFS.Info] "+format, v...) +} + +func (l *nfsLogger) Print(v ...interface{}) { + v = append([]interface{}{"[NFS.Print] "}, v...) + syslog.L.Info(v...) +} + +func (l *nfsLogger) Printf(format string, v ...interface{}) { + syslog.L.Infof("[NFS.Print] "+format, v...) +} + +func (l *nfsLogger) Debug(v ...interface{}) { + v = append([]interface{}{"[NFS.Debug] "}, v...) + syslog.L.Info(v...) +} + +func (l *nfsLogger) Debugf(format string, v ...interface{}) { + syslog.L.Infof("[NFS.Debug] "+format, v...) +} + +func (l *nfsLogger) Error(v ...interface{}) { + v = append([]interface{}{"[NFS.Error] "}, v...) + syslog.L.Error(v...) +} + +func (l *nfsLogger) Errorf(format string, v ...interface{}) { + syslog.L.Errorf("[NFS.Error] "+format, v...) +} + +func (l *nfsLogger) Panic(v ...interface{}) { + v = append([]interface{}{"[NFS.Panic] "}, v...) + syslog.L.Error(v...) +} + +func (l *nfsLogger) Panicf(format string, v ...interface{}) { + syslog.L.Errorf("[NFS.Panic] "+format, v...) +} + +func (l *nfsLogger) Trace(v ...interface{}) { + v = append([]interface{}{"[NFS.Trace] "}, v...) + syslog.L.Info(v...) +} + +func (l *nfsLogger) Tracef(format string, v ...interface{}) { + syslog.L.Infof("[NFS.Trace] "+format, v...) 
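+ // Trace is mapped to Info here; the agent's syslog wrapper exposes no trace level.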
+} + +func (l *nfsLogger) Warn(v ...interface{}) { + v = append([]interface{}{"[NFS.Warn] "}, v...) + syslog.L.Warn(v...) +} + +func (l *nfsLogger) Warnf(format string, v ...interface{}) { + syslog.L.Warnf("[NFS.Warn] "+format, v...) +} diff --git a/internal/agent/nfs/nfs.go b/internal/agent/nfs/nfs.go index e1ad66c..e569c8f 100644 --- a/internal/agent/nfs/nfs.go +++ b/internal/agent/nfs/nfs.go @@ -1,110 +1,110 @@ -//go:build windows -// +build windows - -package nfs - -import ( - "context" - "fmt" - "net" - "net/url" - "sync" - - "github.com/go-git/go-billy/v5" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs/vssfs" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" - nfs "github.com/willscott/go-nfs" -) - -type NFSSession struct { - Context context.Context - ctxCancel context.CancelFunc - Snapshot *snapshots.WinVSSSnapshot - DriveLetter string - listener net.Listener - connections sync.WaitGroup - isRunning bool - serverURL *url.URL - FS billy.Filesystem - statusMu sync.RWMutex -} - -func NewNFSSession(ctx context.Context, snapshot *snapshots.WinVSSSnapshot, driveLetter string) *NFSSession { - cancellableCtx, cancel := context.WithCancel(ctx) - - urlStr, err := registry.GetEntry(registry.CONFIG, "ServerURL", false) - if err != nil { - syslog.L.Errorf("[NewNFSSession] unable to get server url: %v", err) - - cancel() - return nil - } - - parsedURL, _ := url.Parse(urlStr.Value) - - return &NFSSession{ - Context: cancellableCtx, - Snapshot: snapshot, - DriveLetter: driveLetter, - ctxCancel: cancel, - isRunning: true, - serverURL: parsedURL, - FS: vssfs.NewVSSFS( - snapshot, - "/", - ), - } -} - -func (s *NFSSession) Close() { - s.statusMu.Lock() - s.isRunning = false - s.statusMu.Unlock() - - s.ctxCancel() - if s.listener != nil { - s.listener.Close() - } - s.connections.Wait() - s.Snapshot.Close() -} - -func (s *NFSSession) Serve() error { - port, err := utils.DriveLetterPort([]rune(s.DriveLetter)[0]) - if err != nil { - return fmt.Errorf("unable to determine port number: %v", err) - } - - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%s", port)) - if err != nil { - return fmt.Errorf("failed to start listener: %v", err) - } - s.listener = listener - defer listener.Close() - - handler := &NFSHandler{ - session: s, - } - - // nfs.SetLogger(&nfsLogger{}) - - syslog.L.Infof("[NFS.Serve] Serving NFS on port %s", port) - - vssHandler, err := vssfs.NewVSSIDHandler(s.FS.(*vssfs.VSSFS), handler) - if err != nil { - return fmt.Errorf("unable to handle nfs: %w", err) - } - defer vssHandler.ClearHandles() - - return nfs.Serve(listener, vssHandler) -} - -func (s *NFSSession) IsRunning() bool { - s.statusMu.RLock() - defer s.statusMu.RUnlock() - - return s.isRunning -} +//go:build windows +// +build windows + +package nfs + +import ( + "context" + "fmt" + "net" + "net/url" + "sync" + + "github.com/go-git/go-billy/v5" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs/vssfs" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" + nfs "github.com/willscott/go-nfs" +) + +type NFSSession struct { + Context context.Context + ctxCancel context.CancelFunc + Snapshot *snapshots.WinVSSSnapshot + DriveLetter string + listener 
net.Listener + connections sync.WaitGroup + isRunning bool + serverURL *url.URL + FS billy.Filesystem + statusMu sync.RWMutex +} + +func NewNFSSession(ctx context.Context, snapshot *snapshots.WinVSSSnapshot, driveLetter string) *NFSSession { + cancellableCtx, cancel := context.WithCancel(ctx) + + urlStr, err := registry.GetEntry(registry.CONFIG, "ServerURL", false) + if err != nil { + syslog.L.Errorf("[NewNFSSession] unable to get server url: %v", err) + + cancel() + return nil + } + + parsedURL, _ := url.Parse(urlStr.Value) + + return &NFSSession{ + Context: cancellableCtx, + Snapshot: snapshot, + DriveLetter: driveLetter, + ctxCancel: cancel, + isRunning: true, + serverURL: parsedURL, + FS: vssfs.NewVSSFS( + snapshot, + "/", + ), + } +} + +func (s *NFSSession) Close() { + s.statusMu.Lock() + s.isRunning = false + s.statusMu.Unlock() + + s.ctxCancel() + if s.listener != nil { + s.listener.Close() + } + s.connections.Wait() + s.Snapshot.Close() +} + +func (s *NFSSession) Serve() error { + port, err := utils.DriveLetterPort([]rune(s.DriveLetter)[0]) + if err != nil { + return fmt.Errorf("unable to determine port number: %v", err) + } + + listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%s", port)) + if err != nil { + return fmt.Errorf("failed to start listener: %v", err) + } + s.listener = listener + defer listener.Close() + + handler := &NFSHandler{ + session: s, + } + + // nfs.SetLogger(&nfsLogger{}) + + syslog.L.Infof("[NFS.Serve] Serving NFS on port %s", port) + + vssHandler, err := vssfs.NewVSSIDHandler(s.FS.(*vssfs.VSSFS), handler) + if err != nil { + return fmt.Errorf("unable to handle nfs: %w", err) + } + defer vssHandler.ClearHandles() + + return nfs.Serve(listener, vssHandler) +} + +func (s *NFSSession) IsRunning() bool { + s.statusMu.RLock() + defer s.statusMu.RUnlock() + + return s.isRunning +} diff --git a/internal/agent/nfs/vssfs/helpers.go b/internal/agent/nfs/vssfs/helpers.go index cb98ae6..11a22d6 100644 --- a/internal/agent/nfs/vssfs/helpers.go +++ b/internal/agent/nfs/vssfs/helpers.go @@ -1,16 +1,16 @@ -//go:build windows - -package vssfs - -import ( - "golang.org/x/sys/windows" -) - -func skipPathWithAttributes(attrs uint32) bool { - return attrs&(windows.FILE_ATTRIBUTE_REPARSE_POINT| - windows.FILE_ATTRIBUTE_DEVICE| - windows.FILE_ATTRIBUTE_OFFLINE| - windows.FILE_ATTRIBUTE_VIRTUAL| - windows.FILE_ATTRIBUTE_RECALL_ON_OPEN| - windows.FILE_ATTRIBUTE_RECALL_ON_DATA_ACCESS) != 0 -} +//go:build windows + +package vssfs + +import ( + "golang.org/x/sys/windows" +) + +func skipPathWithAttributes(attrs uint32) bool { + return attrs&(windows.FILE_ATTRIBUTE_REPARSE_POINT| + windows.FILE_ATTRIBUTE_DEVICE| + windows.FILE_ATTRIBUTE_OFFLINE| + windows.FILE_ATTRIBUTE_VIRTUAL| + windows.FILE_ATTRIBUTE_RECALL_ON_OPEN| + windows.FILE_ATTRIBUTE_RECALL_ON_DATA_ACCESS) != 0 +} diff --git a/internal/agent/nfs/vssfs/vssfs.go b/internal/agent/nfs/vssfs/vssfs.go index 1a2962b..3fc54a5 100644 --- a/internal/agent/nfs/vssfs/vssfs.go +++ b/internal/agent/nfs/vssfs/vssfs.go @@ -1,236 +1,236 @@ -//go:build windows -// +build windows - -package vssfs - -import ( - "fmt" - "os" - "path/filepath" - "strings" - "time" - - securejoin "github.com/cyphar/filepath-securejoin" - "github.com/go-git/go-billy/v5" - "github.com/go-git/go-billy/v5/osfs" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs/windows_utils" - "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" - "golang.org/x/sys/windows" -) - -// VSSFS extends osfs while enforcing read-only operations -type VSSFS struct { - 
billy.Filesystem - snapshot *snapshots.WinVSSSnapshot - root string -} - -var _ billy.Filesystem = (*VSSFS)(nil) - -func NewVSSFS(snapshot *snapshots.WinVSSSnapshot, baseDir string) billy.Filesystem { - fs := &VSSFS{ - Filesystem: osfs.New(filepath.Join(snapshot.SnapshotPath, baseDir), osfs.WithBoundOS()), - snapshot: snapshot, - root: filepath.Join(snapshot.SnapshotPath, baseDir), - } - - return fs -} - -// Override write operations to return read-only errors -func (fs *VSSFS) Create(filename string) (billy.File, error) { - return nil, fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Open(filename string) (billy.File, error) { - return fs.OpenFile(filename, os.O_RDONLY, 0) -} - -func (fs *VSSFS) OpenFile(filename string, flag int, perm os.FileMode) (billy.File, error) { - if flag&(os.O_WRONLY|os.O_RDWR|os.O_APPEND|os.O_CREATE|os.O_TRUNC) != 0 { - return nil, fmt.Errorf("filesystem is read-only") - } - - path, err := fs.abs(filename) - if err != nil { - return nil, err - } - - pathp, err := windows.UTF16PtrFromString(path) - if err != nil { - return nil, err - } - - handle, err := windows.CreateFile( - pathp, - windows.GENERIC_READ, - windows.FILE_SHARE_READ, - nil, - windows.OPEN_EXISTING, - windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_SEQUENTIAL_SCAN, - 0, - ) - if err != nil { - return nil, err - } - - return &vssfile{File: os.NewFile(uintptr(handle), path)}, nil -} - -func (fs *VSSFS) Rename(oldpath, newpath string) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Remove(filename string) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) MkdirAll(filename string, perm os.FileMode) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Symlink(target, link string) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) TempFile(dir, prefix string) (billy.File, error) { - return nil, fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Chmod(name string, mode os.FileMode) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Lchown(name string, uid, gid int) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Chown(name string, uid, gid int) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Chtimes(name string, atime time.Time, mtime time.Time) error { - return fmt.Errorf("filesystem is read-only") -} - -func (fs *VSSFS) Lstat(filename string) (os.FileInfo, error) { - return fs.Stat(filename) -} - -func (fs *VSSFS) Stat(filename string) (os.FileInfo, error) { - windowsPath := filepath.FromSlash(filename) - fullPath, err := fs.abs(filename) - if err != nil { - return nil, err - } - - if filename == "." || filename == "" { - fullPath = fs.root - windowsPath = "." - } - - pathPtr, err := windows.UTF16PtrFromString(fullPath) - if err != nil { - return nil, err - } - - var findData windows.Win32finddata - handle, err := windows.FindFirstFile(pathPtr, &findData) - if err != nil { - return nil, mapWinError(err, filename) - } - defer windows.FindClose(handle) - - foundName := windows.UTF16ToString(findData.FileName[:]) - expectedName := filepath.Base(fullPath) - if filename == "." { - expectedName = foundName - } - - if !strings.EqualFold(foundName, expectedName) { - return nil, os.ErrNotExist - } - - // Use foundName as the file name for FileInfo - name := foundName - if filename == "." { - name = "." 
- } - if filename == "/" { - name = "/" - } - - info := createFileInfoFromFindData(name, windowsPath, &findData) - - return info, nil -} - -func (fs *VSSFS) ReadDir(dirname string) ([]os.FileInfo, error) { - windowsDir := filepath.FromSlash(dirname) - fullDirPath, err := fs.abs(windowsDir) - if err != nil { - return nil, err - } - - if dirname == "." || dirname == "" { - windowsDir = "." - fullDirPath = fs.root - } - searchPath := filepath.Join(fullDirPath, "*") - var findData windows.Win32finddata - handle, err := windows_utils.FindFirstFileEx(searchPath, &findData) - if err != nil { - return nil, mapWinError(err, dirname) - } - defer windows.FindClose(handle) - - var entries []os.FileInfo - for { - name := windows.UTF16ToString(findData.FileName[:]) - if name != "." && name != ".." { - winEntryPath := filepath.Join(windowsDir, name) - if !skipPathWithAttributes(findData.FileAttributes) { - info := createFileInfoFromFindData(name, winEntryPath, &findData) - entries = append(entries, info) - } - } - - if err := windows.FindNextFile(handle, &findData); err != nil { - if err == windows.ERROR_NO_MORE_FILES { - break - } - return nil, err - } - } - return entries, nil -} - -func mapWinError(err error, path string) error { - switch err { - case windows.ERROR_FILE_NOT_FOUND: - return os.ErrNotExist - case windows.ERROR_PATH_NOT_FOUND: - return os.ErrNotExist - case windows.ERROR_ACCESS_DENIED: - return os.ErrPermission - default: - return &os.PathError{ - Op: "access", - Path: path, - Err: err, - } - } -} - -func (fs *VSSFS) abs(filename string) (string, error) { - if filename == fs.root { - filename = string(filepath.Separator) - } - - path, err := securejoin.SecureJoin(fs.root, filename) - if err != nil { - return "", nil - } - - return path, nil -} +//go:build windows +// +build windows + +package vssfs + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" + + securejoin "github.com/cyphar/filepath-securejoin" + "github.com/go-git/go-billy/v5" + "github.com/go-git/go-billy/v5/osfs" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/nfs/windows_utils" + "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" + "golang.org/x/sys/windows" +) + +// VSSFS extends osfs while enforcing read-only operations +type VSSFS struct { + billy.Filesystem + snapshot *snapshots.WinVSSSnapshot + root string +} + +var _ billy.Filesystem = (*VSSFS)(nil) + +func NewVSSFS(snapshot *snapshots.WinVSSSnapshot, baseDir string) billy.Filesystem { + fs := &VSSFS{ + Filesystem: osfs.New(filepath.Join(snapshot.SnapshotPath, baseDir), osfs.WithBoundOS()), + snapshot: snapshot, + root: filepath.Join(snapshot.SnapshotPath, baseDir), + } + + return fs +} + +// Override write operations to return read-only errors +func (fs *VSSFS) Create(filename string) (billy.File, error) { + return nil, fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Open(filename string) (billy.File, error) { + return fs.OpenFile(filename, os.O_RDONLY, 0) +} + +func (fs *VSSFS) OpenFile(filename string, flag int, perm os.FileMode) (billy.File, error) { + if flag&(os.O_WRONLY|os.O_RDWR|os.O_APPEND|os.O_CREATE|os.O_TRUNC) != 0 { + return nil, fmt.Errorf("filesystem is read-only") + } + + path, err := fs.abs(filename) + if err != nil { + return nil, err + } + + pathp, err := windows.UTF16PtrFromString(path) + if err != nil { + return nil, err + } + + handle, err := windows.CreateFile( + pathp, + windows.GENERIC_READ, + windows.FILE_SHARE_READ, + nil, + windows.OPEN_EXISTING, + 
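// FILE_FLAG_BACKUP_SEMANTICS is required for CreateFile to open directories as well as files; FILE_FLAG_SEQUENTIAL_SCAN hints the cache manager that reads will be sequential. + 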
windows.FILE_FLAG_BACKUP_SEMANTICS|windows.FILE_FLAG_SEQUENTIAL_SCAN, + 0, + ) + if err != nil { + return nil, err + } + + return &vssfile{File: os.NewFile(uintptr(handle), path)}, nil +} + +func (fs *VSSFS) Rename(oldpath, newpath string) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Remove(filename string) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) MkdirAll(filename string, perm os.FileMode) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Symlink(target, link string) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) TempFile(dir, prefix string) (billy.File, error) { + return nil, fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Chmod(name string, mode os.FileMode) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Lchown(name string, uid, gid int) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Chown(name string, uid, gid int) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Chtimes(name string, atime time.Time, mtime time.Time) error { + return fmt.Errorf("filesystem is read-only") +} + +func (fs *VSSFS) Lstat(filename string) (os.FileInfo, error) { + return fs.Stat(filename) +} + +func (fs *VSSFS) Stat(filename string) (os.FileInfo, error) { + windowsPath := filepath.FromSlash(filename) + fullPath, err := fs.abs(filename) + if err != nil { + return nil, err + } + + if filename == "." || filename == "" { + fullPath = fs.root + windowsPath = "." + } + + pathPtr, err := windows.UTF16PtrFromString(fullPath) + if err != nil { + return nil, err + } + + var findData windows.Win32finddata + handle, err := windows.FindFirstFile(pathPtr, &findData) + if err != nil { + return nil, mapWinError(err, filename) + } + defer windows.FindClose(handle) + + foundName := windows.UTF16ToString(findData.FileName[:]) + expectedName := filepath.Base(fullPath) + if filename == "." { + expectedName = foundName + } + + if !strings.EqualFold(foundName, expectedName) { + return nil, os.ErrNotExist + } + + // Use foundName as the file name for FileInfo + name := foundName + if filename == "." { + name = "." + } + if filename == "/" { + name = "/" + } + + info := createFileInfoFromFindData(name, windowsPath, &findData) + + return info, nil +} + +func (fs *VSSFS) ReadDir(dirname string) ([]os.FileInfo, error) { + windowsDir := filepath.FromSlash(dirname) + fullDirPath, err := fs.abs(windowsDir) + if err != nil { + return nil, err + } + + if dirname == "." || dirname == "" { + windowsDir = "." + fullDirPath = fs.root + } + searchPath := filepath.Join(fullDirPath, "*") + var findData windows.Win32finddata + handle, err := windows_utils.FindFirstFileEx(searchPath, &findData) + if err != nil { + return nil, mapWinError(err, dirname) + } + defer windows.FindClose(handle) + + var entries []os.FileInfo + for { + name := windows.UTF16ToString(findData.FileName[:]) + if name != "." && name != ".." 
{ + winEntryPath := filepath.Join(windowsDir, name) + if !skipPathWithAttributes(findData.FileAttributes) { + info := createFileInfoFromFindData(name, winEntryPath, &findData) + entries = append(entries, info) + } + } + + if err := windows.FindNextFile(handle, &findData); err != nil { + if err == windows.ERROR_NO_MORE_FILES { + break + } + return nil, err + } + } + return entries, nil +} + +func mapWinError(err error, path string) error { + switch err { + case windows.ERROR_FILE_NOT_FOUND: + return os.ErrNotExist + case windows.ERROR_PATH_NOT_FOUND: + return os.ErrNotExist + case windows.ERROR_ACCESS_DENIED: + return os.ErrPermission + default: + return &os.PathError{ + Op: "access", + Path: path, + Err: err, + } + } +} + +func (fs *VSSFS) abs(filename string) (string, error) { + if filename == fs.root { + filename = string(filepath.Separator) + } + + path, err := securejoin.SecureJoin(fs.root, filename) + if err != nil { + return "", nil + } + + return path, nil +} diff --git a/internal/agent/nfs/vssfs/vssfs_test.go b/internal/agent/nfs/vssfs/vssfs_test.go index ab5048a..f0ea2f7 100644 --- a/internal/agent/nfs/vssfs/vssfs_test.go +++ b/internal/agent/nfs/vssfs/vssfs_test.go @@ -1,153 +1,153 @@ -//go:build windows -// +build windows - -package vssfs - -import ( - "os" - "path/filepath" - "testing" - - "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/willscott/go-nfs/file" - "golang.org/x/sys/windows" -) - -func setupTestEnvironment(t *testing.T) (string, *snapshots.WinVSSSnapshot, func()) { - tempDir, err := os.MkdirTemp("", "vssfs-test-") - require.NoError(t, err) - - // Create test directory structure using Windows paths - dirs := []string{ - "testdata", - "testdata/subdir", - "testdata/excluded_dir", - } - - files := []string{ - "testdata/regular_file.txt", - "testdata/subdir/file_in_subdir.txt", - "testdata/system_file.txt", - } - - for _, dir := range dirs { - err := os.MkdirAll(filepath.Join(tempDir, dir), 0755) - require.NoError(t, err) - } - - for _, file := range files { - err := os.WriteFile(filepath.Join(tempDir, file), []byte("test"), 0644) - require.NoError(t, err) - } - - // Set system attribute on test file - systemFile := filepath.Join(tempDir, "testdata/system_file.txt") - err = windows.SetFileAttributes( - windows.StringToUTF16Ptr(systemFile), - windows.FILE_ATTRIBUTE_SYSTEM, - ) - require.NoError(t, err) - - snapshot := &snapshots.WinVSSSnapshot{ - SnapshotPath: tempDir, - } - - cleanup := func() { - os.RemoveAll(tempDir) - } - - return tempDir, snapshot, cleanup -} - -func TestStat(t *testing.T) { - _, snapshot, cleanup := setupTestEnvironment(t) - defer cleanup() - - fs := NewVSSFS(snapshot, "testdata").(*VSSFS) - - t.Run("regular file", func(t *testing.T) { - info, err := fs.Stat("regular_file.txt") - require.NoError(t, err) - assert.False(t, info.IsDir()) - assert.Equal(t, "regular_file.txt", info.Name()) - }) - - t.Run("directory", func(t *testing.T) { - info, err := fs.Stat("subdir") - require.NoError(t, err) - assert.True(t, info.IsDir()) - }) - - t.Run("root directory", func(t *testing.T) { - info, err := fs.Stat("/") - require.NoError(t, err) - assert.True(t, info.IsDir()) - assert.Equal(t, "/", info.Name()) - }) - - t.Run("current directory", func(t *testing.T) { - info, err := fs.Stat(".") - require.NoError(t, err) - assert.True(t, info.IsDir()) - assert.Equal(t, ".", info.Name()) - }) -} - -func TestReadDir(t *testing.T) { - _, snapshot, cleanup := 
setupTestEnvironment(t) - defer cleanup() - - fs := NewVSSFS(snapshot, "testdata").(*VSSFS) - - t.Run("root directory listing", func(t *testing.T) { - entries, err := fs.ReadDir("/") - require.NoError(t, err) - - names := make([]string, len(entries)) - for i, e := range entries { - names[i] = e.Name() - } - - assert.ElementsMatch(t, []string{"regular_file.txt", "subdir", "system_file.txt", "excluded_dir"}, names) - }) -} - -func TestPathHandling(t *testing.T) { - _, snapshot, cleanup := setupTestEnvironment(t) - defer cleanup() - fs := NewVSSFS(snapshot, "testdata").(*VSSFS) - - t.Run("mixed slashes in path", func(t *testing.T) { - info, err := fs.Stat("subdir\\file_in_subdir.txt") - require.NoError(t, err) - assert.Equal(t, "file_in_subdir.txt", info.Name()) - }) - - t.Run("relative path resolution", func(t *testing.T) { - info, err := fs.Stat("./subdir/../regular_file.txt") - require.NoError(t, err) - assert.Equal(t, "regular_file.txt", info.Name()) - }) -} - -func TestNFSMetadata(t *testing.T) { - _, snapshot, cleanup := setupTestEnvironment(t) - defer cleanup() - fs := NewVSSFS(snapshot, "testdata").(*VSSFS) - - t.Run("file metadata", func(t *testing.T) { - info, err := fs.Stat("regular_file.txt") - require.NoError(t, err) - sys := info.(*VSSFileInfo).Sys().(file.FileInfo) - assert.NotZero(t, sys.Fileid) - }) - - t.Run("directory metadata", func(t *testing.T) { - info, err := fs.Stat("subdir") - require.NoError(t, err) - sys := info.(*VSSFileInfo).Sys().(file.FileInfo) - assert.Equal(t, uint32(2), sys.Nlink) - }) -} +//go:build windows +// +build windows + +package vssfs + +import ( + "os" + "path/filepath" + "testing" + + "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/willscott/go-nfs/file" + "golang.org/x/sys/windows" +) + +func setupTestEnvironment(t *testing.T) (string, *snapshots.WinVSSSnapshot, func()) { + tempDir, err := os.MkdirTemp("", "vssfs-test-") + require.NoError(t, err) + + // Create test directory structure using Windows paths + dirs := []string{ + "testdata", + "testdata/subdir", + "testdata/excluded_dir", + } + + files := []string{ + "testdata/regular_file.txt", + "testdata/subdir/file_in_subdir.txt", + "testdata/system_file.txt", + } + + for _, dir := range dirs { + err := os.MkdirAll(filepath.Join(tempDir, dir), 0755) + require.NoError(t, err) + } + + for _, file := range files { + err := os.WriteFile(filepath.Join(tempDir, file), []byte("test"), 0644) + require.NoError(t, err) + } + + // Set system attribute on test file + systemFile := filepath.Join(tempDir, "testdata/system_file.txt") + err = windows.SetFileAttributes( + windows.StringToUTF16Ptr(systemFile), + windows.FILE_ATTRIBUTE_SYSTEM, + ) + require.NoError(t, err) + + snapshot := &snapshots.WinVSSSnapshot{ + SnapshotPath: tempDir, + } + + cleanup := func() { + os.RemoveAll(tempDir) + } + + return tempDir, snapshot, cleanup +} + +func TestStat(t *testing.T) { + _, snapshot, cleanup := setupTestEnvironment(t) + defer cleanup() + + fs := NewVSSFS(snapshot, "testdata").(*VSSFS) + + t.Run("regular file", func(t *testing.T) { + info, err := fs.Stat("regular_file.txt") + require.NoError(t, err) + assert.False(t, info.IsDir()) + assert.Equal(t, "regular_file.txt", info.Name()) + }) + + t.Run("directory", func(t *testing.T) { + info, err := fs.Stat("subdir") + require.NoError(t, err) + assert.True(t, info.IsDir()) + }) + + t.Run("root directory", func(t *testing.T) { + info, err := fs.Stat("/") + 
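// The snapshot root must present itself as a directory literally named "/". + 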
require.NoError(t, err) + assert.True(t, info.IsDir()) + assert.Equal(t, "/", info.Name()) + }) + + t.Run("current directory", func(t *testing.T) { + info, err := fs.Stat(".") + require.NoError(t, err) + assert.True(t, info.IsDir()) + assert.Equal(t, ".", info.Name()) + }) +} + +func TestReadDir(t *testing.T) { + _, snapshot, cleanup := setupTestEnvironment(t) + defer cleanup() + + fs := NewVSSFS(snapshot, "testdata").(*VSSFS) + + t.Run("root directory listing", func(t *testing.T) { + entries, err := fs.ReadDir("/") + require.NoError(t, err) + + names := make([]string, len(entries)) + for i, e := range entries { + names[i] = e.Name() + } + + assert.ElementsMatch(t, []string{"regular_file.txt", "subdir", "system_file.txt", "excluded_dir"}, names) + }) +} + +func TestPathHandling(t *testing.T) { + _, snapshot, cleanup := setupTestEnvironment(t) + defer cleanup() + fs := NewVSSFS(snapshot, "testdata").(*VSSFS) + + t.Run("mixed slashes in path", func(t *testing.T) { + info, err := fs.Stat("subdir\\file_in_subdir.txt") + require.NoError(t, err) + assert.Equal(t, "file_in_subdir.txt", info.Name()) + }) + + t.Run("relative path resolution", func(t *testing.T) { + info, err := fs.Stat("./subdir/../regular_file.txt") + require.NoError(t, err) + assert.Equal(t, "regular_file.txt", info.Name()) + }) +} + +func TestNFSMetadata(t *testing.T) { + _, snapshot, cleanup := setupTestEnvironment(t) + defer cleanup() + fs := NewVSSFS(snapshot, "testdata").(*VSSFS) + + t.Run("file metadata", func(t *testing.T) { + info, err := fs.Stat("regular_file.txt") + require.NoError(t, err) + sys := info.(*VSSFileInfo).Sys().(file.FileInfo) + assert.NotZero(t, sys.Fileid) + }) + + t.Run("directory metadata", func(t *testing.T) { + info, err := fs.Stat("subdir") + require.NoError(t, err) + sys := info.(*VSSFileInfo).Sys().(file.FileInfo) + assert.Equal(t, uint32(2), sys.Nlink) + }) +} diff --git a/internal/agent/registry/paths.go b/internal/agent/registry/paths.go index d03e0d4..6ccc2eb 100644 --- a/internal/agent/registry/paths.go +++ b/internal/agent/registry/paths.go @@ -1,4 +1,4 @@ -package registry - -const CONFIG = "Software\\PBSPlus\\Config" -const AUTH = "Software\\PBSPlus\\Auth" +package registry + +const CONFIG = "Software\\PBSPlus\\Config" +const AUTH = "Software\\PBSPlus\\Auth" diff --git a/internal/agent/registry/registry.go b/internal/agent/registry/registry.go index 9452ad1..d67de3c 100644 --- a/internal/agent/registry/registry.go +++ b/internal/agent/registry/registry.go @@ -1,127 +1,127 @@ -//go:build windows - -package registry - -import ( - "fmt" - - "github.com/billgraziano/dpapi" - "golang.org/x/sys/windows/registry" -) - -type RegistryEntry struct { - Path string - Key string - Value string - IsSecret bool -} - -// GetEntry retrieves a registry entry -func GetEntry(path string, key string, isSecret bool) (*RegistryEntry, error) { - baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) - if err != nil { - return nil, fmt.Errorf("GetEntry error: %w", err) - } - defer baseKey.Close() - - value, _, err := baseKey.GetStringValue(key) - if err != nil { - return nil, fmt.Errorf("GetEntry error: %w", err) - } - - if isSecret { - value, err = dpapi.Decrypt(value) - if err != nil { - return nil, fmt.Errorf("GetEntry error: %w", err) - } - } - - return &RegistryEntry{ - Path: path, - Key: key, - Value: value, - IsSecret: isSecret, - }, nil -} - -// CreateEntry creates a new registry entry -func CreateEntry(entry *RegistryEntry) error { - baseKey, err := 
registry.OpenKey(registry.LOCAL_MACHINE, entry.Path, registry.SET_VALUE) - if err != nil { - // If the key doesn't exist, create it - baseKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, entry.Path, registry.SET_VALUE) - if err != nil { - return fmt.Errorf("CreateEntry error creating key: %w", err) - } - } - defer baseKey.Close() - - value := entry.Value - if entry.IsSecret { - encrypted, err := dpapi.Encrypt(value) - if err != nil { - return fmt.Errorf("CreateEntry error encrypting: %w", err) - } - value = encrypted - } - - err = baseKey.SetStringValue(entry.Key, value) - if err != nil { - return fmt.Errorf("CreateEntry error setting value: %w", err) - } - - return nil -} - -// UpdateEntry updates an existing registry entry -func UpdateEntry(entry *RegistryEntry) error { - // First check if the entry exists - _, err := GetEntry(entry.Path, entry.Key, entry.IsSecret) - if err != nil { - return fmt.Errorf("UpdateEntry error: entry does not exist: %w", err) - } - - // Reuse CreateEntry logic for the update - return CreateEntry(entry) -} - -// DeleteEntry deletes a registry entry -func DeleteEntry(path string, key string) error { - baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.SET_VALUE) - if err != nil { - return fmt.Errorf("DeleteEntry error opening key: %w", err) - } - defer baseKey.Close() - - err = baseKey.DeleteValue(key) - if err != nil { - return fmt.Errorf("DeleteEntry error deleting value: %w", err) - } - - return nil -} - -// DeleteKey deletes an entire registry key and all its values -func DeleteKey(path string) error { - err := registry.DeleteKey(registry.LOCAL_MACHINE, path) - if err != nil { - return fmt.Errorf("DeleteKey error: %w", err) - } - return nil -} - -// ListEntries lists all values in a registry key -func ListEntries(path string) ([]string, error) { - baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) - if err != nil { - return nil, fmt.Errorf("ListEntries error opening key: %w", err) - } - defer baseKey.Close() - - valueNames, err := baseKey.ReadValueNames(0) - if err != nil { - return nil, fmt.Errorf("ListEntries error reading values: %w", err) - } - - return valueNames, nil -} +//go:build windows + +package registry + +import ( + "fmt" + + "github.com/billgraziano/dpapi" + "golang.org/x/sys/windows/registry" +) + +type RegistryEntry struct { + Path string + Key string + Value string + IsSecret bool +} + +// GetEntry retrieves a registry entry +func GetEntry(path string, key string, isSecret bool) (*RegistryEntry, error) { + baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + return nil, fmt.Errorf("GetEntry error: %w", err) + } + defer baseKey.Close() + + value, _, err := baseKey.GetStringValue(key) + if err != nil { + return nil, fmt.Errorf("GetEntry error: %w", err) + } + + if isSecret { + value, err = dpapi.Decrypt(value) + if err != nil { + return nil, fmt.Errorf("GetEntry error: %w", err) + } + } + + return &RegistryEntry{ + Path: path, + Key: key, + Value: value, + IsSecret: isSecret, + }, nil +} + +// CreateEntry creates a new registry entry +func CreateEntry(entry *RegistryEntry) error { + baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, entry.Path, registry.SET_VALUE) + if err != nil { + // If the key doesn't exist, create it + baseKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, entry.Path, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("CreateEntry error creating key: %w", err) + } + } + defer 
baseKey.Close() + + value := entry.Value + if entry.IsSecret { + encrypted, err := dpapi.Encrypt(value) + if err != nil { + return fmt.Errorf("CreateEntry error encrypting: %w", err) + } + value = encrypted + } + + err = baseKey.SetStringValue(entry.Key, value) + if err != nil { + return fmt.Errorf("CreateEntry error setting value: %w", err) + } + + return nil +} + +// UpdateEntry updates an existing registry entry +func UpdateEntry(entry *RegistryEntry) error { + // First check if the entry exists + _, err := GetEntry(entry.Path, entry.Key, entry.IsSecret) + if err != nil { + return fmt.Errorf("UpdateEntry error: entry does not exist: %w", err) + } + + // Reuse CreateEntry logic for the update + return CreateEntry(entry) +} + +// DeleteEntry deletes a registry entry +func DeleteEntry(path string, key string) error { + baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.SET_VALUE) + if err != nil { + return fmt.Errorf("DeleteEntry error opening key: %w", err) + } + defer baseKey.Close() + + err = baseKey.DeleteValue(key) + if err != nil { + return fmt.Errorf("DeleteEntry error deleting value: %w", err) + } + + return nil +} + +// DeleteKey deletes an entire registry key and all its values +func DeleteKey(path string) error { + err := registry.DeleteKey(registry.LOCAL_MACHINE, path) + if err != nil { + return fmt.Errorf("DeleteKey error: %w", err) + } + return nil +} + +// ListEntries lists all values in a registry key +func ListEntries(path string) ([]string, error) { + baseKey, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + return nil, fmt.Errorf("ListEntries error opening key: %w", err) + } + defer baseKey.Close() + + valueNames, err := baseKey.ReadValueNames(0) + if err != nil { + return nil, fmt.Errorf("ListEntries error reading values: %w", err) + } + + return valueNames, nil +} diff --git a/internal/agent/snapshots/windows.go b/internal/agent/snapshots/windows.go index 93ab623..903ce42 100644 --- a/internal/agent/snapshots/windows.go +++ b/internal/agent/snapshots/windows.go @@ -1,154 +1,154 @@ -//go:build windows -// +build windows - -package snapshots - -import ( - "context" - "errors" - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "sync/atomic" - "time" - - "github.com/mxk/go-vss" -) - -var ( - ErrSnapshotTimeout = errors.New("timeout waiting for in-progress snapshot") - ErrSnapshotCreation = errors.New("failed to create snapshot") - ErrInvalidSnapshot = errors.New("invalid snapshot") -) - -type WinVSSSnapshot struct { - SnapshotPath string `json:"path"` - Id string `json:"vss_id"` - TimeStarted time.Time `json:"time_started"` - DriveLetter string - closed atomic.Bool -} - -func getVSSFolder() (string, error) { - tmpDir := os.TempDir() - configBasePath := filepath.Join(tmpDir, "pbs-plus-vss") - if err := os.MkdirAll(configBasePath, 0750); err != nil { - return "", fmt.Errorf("failed to create VSS directory %q: %w", configBasePath, err) - } - return configBasePath, nil -} - -// Snapshot creates a new VSS snapshot for the specified drive -func Snapshot(driveLetter string) (*WinVSSSnapshot, error) { - volName := filepath.VolumeName(fmt.Sprintf("%s:", driveLetter)) - vssFolder, err := getVSSFolder() - if err != nil { - return nil, fmt.Errorf("error getting VSS folder: %w", err) - } - - snapshotPath := filepath.Join(vssFolder, driveLetter) - timeStarted := time.Now() - - cleanupExistingSnapshot(snapshotPath) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - if 
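The registry helpers above give the agent a single DPAPI-backed store for both plain settings and secrets. A minimal usage sketch, assuming only the API in this hunk; the key name ExampleToken is invented for illustration:

//go:build windows

package main

import (
	"fmt"
	"log"

	"github.com/sonroyaalmerol/pbs-plus/internal/agent/registry"
)

func main() {
	// IsSecret routes the value through DPAPI before it reaches the
	// registry, so it can only be decrypted in the same machine context.
	err := registry.CreateEntry(&registry.RegistryEntry{
		Path:     registry.CONFIG,
		Key:      "ExampleToken", // illustrative key name
		Value:    "s3cret",
		IsSecret: true,
	})
	if err != nil {
		log.Fatal(err)
	}

	// GetEntry decrypts transparently when isSecret is true.
	got, err := registry.GetEntry(registry.CONFIG, "ExampleToken", true)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(got.Value)
}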
err := createSnapshotWithRetry(ctx, snapshotPath, volName); err != nil { - cleanupExistingSnapshot(snapshotPath) - return nil, fmt.Errorf("snapshot creation failed: %w", err) - } - - sc, err := vss.Get(snapshotPath) - if err != nil { - cleanupExistingSnapshot(snapshotPath) - return nil, fmt.Errorf("snapshot validation failed: %w", err) - } - - snapshot := &WinVSSSnapshot{ - SnapshotPath: snapshotPath, - Id: sc.ID, - TimeStarted: timeStarted, - DriveLetter: driveLetter, - } - - return snapshot, nil -} - -// reregisterVSSWriters attempts to restart VSS services when needed -func reregisterVSSWriters() error { - services := []string{ - "Winmgmt", // Windows Management Instrumentation - "VSS", // Volume Shadow Copy - "swprv", // Microsoft Software Shadow Copy Provider - } - - for _, svc := range services { - if err := exec.Command("net", "stop", svc).Run(); err != nil { - return fmt.Errorf("failed to stop service %s: %w", svc, err) - } - } - - for i := len(services) - 1; i >= 0; i-- { - if err := exec.Command("net", "start", services[i]).Run(); err != nil { - return fmt.Errorf("failed to start service %s: %w", services[i], err) - } - } - - return nil -} - -func createSnapshotWithRetry(ctx context.Context, snapshotPath, volName string) error { - const retryInterval = time.Second - var lastError error - - for attempts := 0; attempts < 2; attempts++ { - for { - if err := vss.CreateLink(snapshotPath, volName); err == nil { - return nil - } else if !strings.Contains(err.Error(), "shadow copy operation is already in progress") { - lastError = err - // If this is our first attempt and it's a VSS-related error, - // try re-registering writers - if attempts == 0 && (strings.Contains(err.Error(), "VSS") || - strings.Contains(err.Error(), "shadow copy")) { - fmt.Println("VSS error detected, attempting to re-register writers...") - if reregErr := reregisterVSSWriters(); reregErr != nil { - fmt.Printf("Warning: failed to re-register VSS writers: %v\n", reregErr) - } - // Break inner loop to start fresh after re-registration - break - } - return fmt.Errorf("%w: %v", ErrSnapshotCreation, err) - } - - select { - case <-ctx.Done(): - return ErrSnapshotTimeout - case <-time.After(retryInterval): - continue - } - } - } - - return fmt.Errorf("%w: %v", ErrSnapshotCreation, lastError) -} - -func cleanupExistingSnapshot(path string) { - if sc, err := vss.Get(path); err == nil { - _ = vss.Remove(sc.ID) - } - - _ = os.Remove(path) -} - -func (s *WinVSSSnapshot) Close() { - if s == nil || !s.closed.CompareAndSwap(false, true) { - return - } - - _ = vss.Remove(s.Id) - _ = os.Remove(s.SnapshotPath) -} +//go:build windows +// +build windows + +package snapshots + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "github.com/mxk/go-vss" +) + +var ( + ErrSnapshotTimeout = errors.New("timeout waiting for in-progress snapshot") + ErrSnapshotCreation = errors.New("failed to create snapshot") + ErrInvalidSnapshot = errors.New("invalid snapshot") +) + +type WinVSSSnapshot struct { + SnapshotPath string `json:"path"` + Id string `json:"vss_id"` + TimeStarted time.Time `json:"time_started"` + DriveLetter string + closed atomic.Bool +} + +func getVSSFolder() (string, error) { + tmpDir := os.TempDir() + configBasePath := filepath.Join(tmpDir, "pbs-plus-vss") + if err := os.MkdirAll(configBasePath, 0750); err != nil { + return "", fmt.Errorf("failed to create VSS directory %q: %w", configBasePath, err) + } + return configBasePath, nil +} + +// Snapshot creates a new 
VSS snapshot for the specified drive +func Snapshot(driveLetter string) (*WinVSSSnapshot, error) { + volName := filepath.VolumeName(fmt.Sprintf("%s:", driveLetter)) + vssFolder, err := getVSSFolder() + if err != nil { + return nil, fmt.Errorf("error getting VSS folder: %w", err) + } + + snapshotPath := filepath.Join(vssFolder, driveLetter) + timeStarted := time.Now() + + cleanupExistingSnapshot(snapshotPath) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + if err := createSnapshotWithRetry(ctx, snapshotPath, volName); err != nil { + cleanupExistingSnapshot(snapshotPath) + return nil, fmt.Errorf("snapshot creation failed: %w", err) + } + + sc, err := vss.Get(snapshotPath) + if err != nil { + cleanupExistingSnapshot(snapshotPath) + return nil, fmt.Errorf("snapshot validation failed: %w", err) + } + + snapshot := &WinVSSSnapshot{ + SnapshotPath: snapshotPath, + Id: sc.ID, + TimeStarted: timeStarted, + DriveLetter: driveLetter, + } + + return snapshot, nil +} + +// reregisterVSSWriters attempts to restart VSS services when needed +func reregisterVSSWriters() error { + services := []string{ + "Winmgmt", // Windows Management Instrumentation + "VSS", // Volume Shadow Copy + "swprv", // Microsoft Software Shadow Copy Provider + } + + for _, svc := range services { + if err := exec.Command("net", "stop", svc).Run(); err != nil { + return fmt.Errorf("failed to stop service %s: %w", svc, err) + } + } + + for i := len(services) - 1; i >= 0; i-- { + if err := exec.Command("net", "start", services[i]).Run(); err != nil { + return fmt.Errorf("failed to start service %s: %w", services[i], err) + } + } + + return nil +} + +func createSnapshotWithRetry(ctx context.Context, snapshotPath, volName string) error { + const retryInterval = time.Second + var lastError error + + for attempts := 0; attempts < 2; attempts++ { + for { + if err := vss.CreateLink(snapshotPath, volName); err == nil { + return nil + } else if !strings.Contains(err.Error(), "shadow copy operation is already in progress") { + lastError = err + // If this is our first attempt and it's a VSS-related error, + // try re-registering writers + if attempts == 0 && (strings.Contains(err.Error(), "VSS") || + strings.Contains(err.Error(), "shadow copy")) { + fmt.Println("VSS error detected, attempting to re-register writers...") + if reregErr := reregisterVSSWriters(); reregErr != nil { + fmt.Printf("Warning: failed to re-register VSS writers: %v\n", reregErr) + } + // Break inner loop to start fresh after re-registration + break + } + return fmt.Errorf("%w: %v", ErrSnapshotCreation, err) + } + + select { + case <-ctx.Done(): + return ErrSnapshotTimeout + case <-time.After(retryInterval): + continue + } + } + } + + return fmt.Errorf("%w: %v", ErrSnapshotCreation, lastError) +} + +func cleanupExistingSnapshot(path string) { + if sc, err := vss.Get(path); err == nil { + _ = vss.Remove(sc.ID) + } + + _ = os.Remove(path) +} + +func (s *WinVSSSnapshot) Close() { + if s == nil || !s.closed.CompareAndSwap(false, true) { + return + } + + _ = vss.Remove(s.Id) + _ = os.Remove(s.SnapshotPath) +} diff --git a/internal/agent/systray_comm.go b/internal/agent/systray_comm.go index 1ff1924..6dc3aeb 100644 --- a/internal/agent/systray_comm.go +++ b/internal/agent/systray_comm.go @@ -1,28 +1,28 @@ -//go:build windows - -package agent - -import ( - "golang.org/x/sys/windows/registry" -) - -func SetStatus(status string) error { - key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, `Software\PBSPlus`, 
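The snapshot lifecycle above pairs a guarded constructor with an idempotent Close. A sketch of the intended call pattern, assuming only this hunk's API:

//go:build windows

package main

import (
	"log"

	"github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots"
)

func main() {
	// Snapshot removes any stale link first, then retries while another
	// shadow-copy operation is in progress (bounded by a 5-minute context).
	snap, err := snapshots.Snapshot("C")
	if err != nil {
		log.Fatal(err)
	}
	// Close is safe to defer unconditionally: the atomic closed flag makes
	// a second call a no-op.
	defer snap.Close()

	log.Printf("reading files from %s (VSS id %s)", snap.SnapshotPath, snap.Id)
}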
registry.ALL_ACCESS) - if err == nil { - defer key.Close() - err := key.SetStringValue("Status", status) - return err - } - return err -} - -func GetStatus() (string, error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\PBSPlus`, registry.QUERY_VALUE) - if err == nil { - defer key.Close() - regStatus, _, err := key.GetStringValue("Status") - return regStatus, err - } - - return "", err -} +//go:build windows + +package agent + +import ( + "golang.org/x/sys/windows/registry" +) + +func SetStatus(status string) error { + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, `Software\PBSPlus`, registry.ALL_ACCESS) + if err == nil { + defer key.Close() + err := key.SetStringValue("Status", status) + return err + } + return err +} + +func GetStatus() (string, error) { + key, err := registry.OpenKey(registry.LOCAL_MACHINE, `Software\PBSPlus`, registry.QUERY_VALUE) + if err == nil { + defer key.Close() + regStatus, _, err := key.GetStringValue("Status") + return regStatus, err + } + + return "", err +} diff --git a/internal/agent/tls_config.go b/internal/agent/tls_config.go index 7817787..7a60462 100644 --- a/internal/agent/tls_config.go +++ b/internal/agent/tls_config.go @@ -1,174 +1,174 @@ -//go:build windows - -package agent - -import ( - "bytes" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" - "encoding/pem" - "fmt" - "net/http" - "os" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" - "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func GetTLSConfig() (*tls.Config, error) { - serverCertReg, err := registry.GetEntry(registry.AUTH, "ServerCA", true) - if err != nil { - return nil, fmt.Errorf("GetTLSConfig: server cert not found -> %w", err) - } - - rootCAs := x509.NewCertPool() - if ok := rootCAs.AppendCertsFromPEM([]byte(serverCertReg.Value)); !ok { - return nil, fmt.Errorf("failed to append CA certificate: %s", serverCertReg.Value) - } - - certReg, err := registry.GetEntry(registry.AUTH, "Cert", true) - if err != nil { - return nil, fmt.Errorf("GetTLSConfig: cert not found -> %w", err) - } - - keyReg, err := registry.GetEntry(registry.AUTH, "Priv", true) - if err != nil { - return nil, fmt.Errorf("GetTLSConfig: key not found -> %w", err) - } - - certPEM := []byte(certReg.Value) - keyPEM := []byte(keyReg.Value) - - // Configure TLS client - cert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - return nil, fmt.Errorf("failed to load client certificate: %w\n%v\n%v", err, certPEM, keyPEM) - } - - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: rootCAs, - }, nil -} - -func CheckAndRenewCertificate() error { - const renewalWindow = 30 * 24 * time.Hour // Renew if certificate expires in less than 30 days - - certReg, err := registry.GetEntry(registry.AUTH, "Cert", true) - if err != nil { - return fmt.Errorf("CheckAndRenewCertificate: failed to retrieve certificate - %w", err) - } - - block, _ := pem.Decode([]byte(certReg.Value)) - if block == nil { - return fmt.Errorf("CheckAndRenewCertificate: failed to decode PEM block") - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return fmt.Errorf("CheckAndRenewCertificate: failed to parse certificate - %w", err) - } - - now := time.Now() - timeUntilExpiry := cert.NotAfter.Sub(now) - - switch { - case cert.NotAfter.Before(now): - _ = registry.DeleteEntry(registry.AUTH, "Cert") - _ = registry.DeleteEntry(registry.AUTH, "Priv") - - return 
fmt.Errorf("Certificate has expired. This agent needs to be bootstrapped again.") - case timeUntilExpiry < renewalWindow: - fmt.Printf("Certificate expires in %v hours. Renewing...\n", timeUntilExpiry.Hours()) - return renewCertificate() - default: - fmt.Printf("Certificate valid for %v days. No renewal needed.\n", timeUntilExpiry.Hours()/24) - return nil - } -} - -func renewCertificate() error { - hostname, _ := os.Hostname() - - csr, privKey, err := certificates.GenerateCSR(hostname, 2048) - if err != nil { - return fmt.Errorf("Bootstrap: generating csr failed -> %w", err) - } - - encodedCSR := base64.StdEncoding.EncodeToString(csr) - - drives, err := utils.GetLocalDrives() - if err != nil { - return fmt.Errorf("Bootstrap: failed to get local drives list: %w", err) - } - - reqBody, err := json.Marshal(&BootstrapRequest{ - Hostname: hostname, - Drives: drives, - CSR: encodedCSR, - }) - if err != nil { - return fmt.Errorf("failed to marshal bootstrap request: %w", err) - } - - renewResp := &BootstrapResponse{} - - _, err = ProxmoxHTTPRequest(http.MethodPost, "/plus/agent/renew", bytes.NewBuffer(reqBody), &renewResp) - if err != nil { - return fmt.Errorf("failed to fetch renewed certificate: %w", err) - } - - decodedCA, err := base64.StdEncoding.DecodeString(renewResp.CA) - if err != nil { - return fmt.Errorf("Renew: error decoding ca content (%s) -> %w", string(renewResp.CA), err) - } - - decodedCert, err := base64.StdEncoding.DecodeString(renewResp.Cert) - if err != nil { - return fmt.Errorf("Renew: error decoding cert content (%s) -> %w", string(renewResp.Cert), err) - } - - privKeyPEM := certificates.EncodeKeyPEM(privKey) - - caEntry := registry.RegistryEntry{ - Key: "ServerCA", - Value: string(decodedCA), - Path: registry.AUTH, - IsSecret: true, - } - - certEntry := registry.RegistryEntry{ - Key: "Cert", - Value: string(decodedCert), - Path: registry.AUTH, - IsSecret: true, - } - - privEntry := registry.RegistryEntry{ - Key: "Priv", - Value: string(privKeyPEM), - Path: registry.AUTH, - IsSecret: true, - } - - err = registry.CreateEntry(&caEntry) - if err != nil { - return fmt.Errorf("Renew: error storing ca to registry -> %w", err) - } - - err = registry.CreateEntry(&certEntry) - if err != nil { - return fmt.Errorf("Renew: error storing cert to registry -> %w", err) - } - - err = registry.CreateEntry(&privEntry) - if err != nil { - return fmt.Errorf("Renew: error storing priv to registry -> %w", err) - } - - return nil -} +//go:build windows + +package agent + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "os" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/agent/registry" + "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func GetTLSConfig() (*tls.Config, error) { + serverCertReg, err := registry.GetEntry(registry.AUTH, "ServerCA", true) + if err != nil { + return nil, fmt.Errorf("GetTLSConfig: server cert not found -> %w", err) + } + + rootCAs := x509.NewCertPool() + if ok := rootCAs.AppendCertsFromPEM([]byte(serverCertReg.Value)); !ok { + return nil, fmt.Errorf("failed to append CA certificate: %s", serverCertReg.Value) + } + + certReg, err := registry.GetEntry(registry.AUTH, "Cert", true) + if err != nil { + return nil, fmt.Errorf("GetTLSConfig: cert not found -> %w", err) + } + + keyReg, err := registry.GetEntry(registry.AUTH, "Priv", true) + if err != nil { + return nil, fmt.Errorf("GetTLSConfig: key not found -> 
%w", err) + } + + certPEM := []byte(certReg.Value) + keyPEM := []byte(keyReg.Value) + + // Configure TLS client + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %w\n%v\n%v", err, certPEM, keyPEM) + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: rootCAs, + }, nil +} + +func CheckAndRenewCertificate() error { + const renewalWindow = 30 * 24 * time.Hour // Renew if certificate expires in less than 30 days + + certReg, err := registry.GetEntry(registry.AUTH, "Cert", true) + if err != nil { + return fmt.Errorf("CheckAndRenewCertificate: failed to retrieve certificate - %w", err) + } + + block, _ := pem.Decode([]byte(certReg.Value)) + if block == nil { + return fmt.Errorf("CheckAndRenewCertificate: failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("CheckAndRenewCertificate: failed to parse certificate - %w", err) + } + + now := time.Now() + timeUntilExpiry := cert.NotAfter.Sub(now) + + switch { + case cert.NotAfter.Before(now): + _ = registry.DeleteEntry(registry.AUTH, "Cert") + _ = registry.DeleteEntry(registry.AUTH, "Priv") + + return fmt.Errorf("Certificate has expired. This agent needs to be bootstrapped again.") + case timeUntilExpiry < renewalWindow: + fmt.Printf("Certificate expires in %v hours. Renewing...\n", timeUntilExpiry.Hours()) + return renewCertificate() + default: + fmt.Printf("Certificate valid for %v days. No renewal needed.\n", timeUntilExpiry.Hours()/24) + return nil + } +} + +func renewCertificate() error { + hostname, _ := os.Hostname() + + csr, privKey, err := certificates.GenerateCSR(hostname, 2048) + if err != nil { + return fmt.Errorf("Bootstrap: generating csr failed -> %w", err) + } + + encodedCSR := base64.StdEncoding.EncodeToString(csr) + + drives, err := utils.GetLocalDrives() + if err != nil { + return fmt.Errorf("Bootstrap: failed to get local drives list: %w", err) + } + + reqBody, err := json.Marshal(&BootstrapRequest{ + Hostname: hostname, + Drives: drives, + CSR: encodedCSR, + }) + if err != nil { + return fmt.Errorf("failed to marshal bootstrap request: %w", err) + } + + renewResp := &BootstrapResponse{} + + _, err = ProxmoxHTTPRequest(http.MethodPost, "/plus/agent/renew", bytes.NewBuffer(reqBody), &renewResp) + if err != nil { + return fmt.Errorf("failed to fetch renewed certificate: %w", err) + } + + decodedCA, err := base64.StdEncoding.DecodeString(renewResp.CA) + if err != nil { + return fmt.Errorf("Renew: error decoding ca content (%s) -> %w", string(renewResp.CA), err) + } + + decodedCert, err := base64.StdEncoding.DecodeString(renewResp.Cert) + if err != nil { + return fmt.Errorf("Renew: error decoding cert content (%s) -> %w", string(renewResp.Cert), err) + } + + privKeyPEM := certificates.EncodeKeyPEM(privKey) + + caEntry := registry.RegistryEntry{ + Key: "ServerCA", + Value: string(decodedCA), + Path: registry.AUTH, + IsSecret: true, + } + + certEntry := registry.RegistryEntry{ + Key: "Cert", + Value: string(decodedCert), + Path: registry.AUTH, + IsSecret: true, + } + + privEntry := registry.RegistryEntry{ + Key: "Priv", + Value: string(privKeyPEM), + Path: registry.AUTH, + IsSecret: true, + } + + err = registry.CreateEntry(&caEntry) + if err != nil { + return fmt.Errorf("Renew: error storing ca to registry -> %w", err) + } + + err = registry.CreateEntry(&certEntry) + if err != nil { + return fmt.Errorf("Renew: error storing cert to registry -> %w", err) + } + + err = 
registry.CreateEntry(&privEntry) + if err != nil { + return fmt.Errorf("Renew: error storing priv to registry -> %w", err) + } + + return nil +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 94264c9..897ec6f 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -1,245 +1,245 @@ -//go:build linux - -package auth - -import ( - "bytes" - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "net" - "net/http" - "os" - "path/filepath" - "strings" - "sync" - "testing" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" - serverLib "github.com/sonroyaalmerol/pbs-plus/internal/auth/server" - "github.com/sonroyaalmerol/pbs-plus/internal/auth/testhelpers" -) - -func TestEndToEnd(t *testing.T) { - // Create temporary directory for test certificates - certsDir, err := os.MkdirTemp("", "auth-test-*") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(certsDir) - - // Generate certificates - certOpts := certificates.DefaultOptions() - certOpts.OutputDir = certsDir - certOpts.Organization = "Test Auth System" - certOpts.ValidDays = 1 - certOpts.Hostnames = []string{"localhost", "test.local"} - certOpts.IPs = []net.IP{net.ParseIP("127.0.0.1")} - - generator, err := certificates.NewGenerator(certOpts) - if err != nil { - t.Fatal(err) - } - - if err := generator.GenerateAll(); err != nil { - t.Fatal(err) - } - - // Start server - serverConfig := serverLib.DefaultConfig() - serverConfig.Address = ":44443" // Different port for testing - serverConfig.CertFile = filepath.Join(certsDir, "server.crt") - serverConfig.KeyFile = filepath.Join(certsDir, "server.key") - serverConfig.CAFile = filepath.Join(certsDir, "ca.crt") - serverConfig.TokenExpiration = 1 * time.Hour - - srv, err := testhelpers.NewServer(serverConfig) - if err != nil { - t.Fatal(err) - } - - // Start server in goroutine - serverErrCh := make(chan error, 1) - go func() { - if err := srv.Start(); err != nil { - if !isClosedConnError(err) { - serverErrCh <- err - } - } - }() - - // Give server time to start - time.Sleep(100 * time.Millisecond) - - defer func() { - // Graceful shutdown - shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := srv.Stop(shutdownCtx); err != nil && !isClosedConnError(err) { - t.Error("shutdown error:", err) - } - - // Check if there were any server errors - select { - case err := <-serverErrCh: - if err != nil { - if !strings.Contains(err.Error(), "Server closed") { - t.Error("server error:", err) - } - } - default: - } - }() - - // Create and start multiple agents - agents := make([]*testhelpers.Agent, 3) - for i := 0; i < 3; i++ { - agentConfig := testhelpers.DefaultAgentConfig() - agentConfig.AgentID = fmt.Sprintf("test-agent-%d", i) - agentConfig.ServerURL = "https://localhost:44443" - agentConfig.CertFile = filepath.Join(certsDir, "agent.crt") - agentConfig.KeyFile = filepath.Join(certsDir, "agent.key") - agentConfig.CAFile = filepath.Join(certsDir, "ca.crt") - agentConfig.Timeout = 5 * time.Second - agentConfig.MaxRetries = 3 - agentConfig.RetryInterval = 100 * time.Millisecond - agentConfig.KeepAlive = true - agentConfig.KeepAliveInterval = 500 * time.Millisecond - - a, err := testhelpers.NewAgent(agentConfig) - if err != nil { - t.Fatal(err) - } - agents[i] = a - } - - // Start all agents - ctx := context.Background() - for i, a := range agents { - if err := a.Start(ctx); err != nil { - t.Fatalf("Failed to start agent %d: %v", i, err) - } - defer a.Stop() - } - - // Test 
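GetTLSConfig above assembles the agent's mTLS client configuration entirely from registry-held, DPAPI-protected material (note the keypair-load error now wraps only err, so private-key bytes no longer leak into error text). A sketch of a call site, assuming this hunk's API; the URL and endpoint are illustrative:

//go:build windows

package main

import (
	"log"
	"net/http"

	"github.com/sonroyaalmerol/pbs-plus/internal/agent"
)

func main() {
	tlsCfg, err := agent.GetTLSConfig() // client cert + server CA from the registry
	if err != nil {
		log.Fatal(err)
	}
	client := &http.Client{
		Transport: &http.Transport{TLSClientConfig: tlsCfg},
	}
	resp, err := client.Get("https://pbs.example.local:8008/plus/ping") // illustrative URL
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	log.Println(resp.Status)
}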
parallel requests from all agents - t.Run("ParallelRequests", func(t *testing.T) { - var wg sync.WaitGroup - errors := make(chan error, len(agents)*3) // 3 requests per agent - - for _, a := range agents { - wg.Add(1) - go func(agent *testhelpers.Agent) { - defer wg.Done() - for i := 0; i < 3; i++ { - resp, err := agent.SendRequest(ctx, fmt.Sprintf("test message %d", i)) - if err != nil { - errors <- err - return - } - if resp == nil || resp.Message == "" { - errors <- fmt.Errorf("empty response") - return - } - time.Sleep(100 * time.Millisecond) - } - }(a) - } - - wg.Wait() - close(errors) - - for err := range errors { - if err != nil { - t.Error(err) - } - } - }) - - // Test token expiration and renewal - t.Run("TokenRenewal", func(t *testing.T) { - // Get initial token - initialToken := agents[0].GetToken() - if initialToken == "" { - t.Fatal("Expected non-empty initial token") - } - - // Wait for a keepalive cycle - time.Sleep(600 * time.Millisecond) - - // Send another request - resp, err := agents[0].SendRequest(ctx, "test token renewal") - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("Expected non-nil response") - } - }) - - // Test invalid requests - t.Run("InvalidRequests", func(t *testing.T) { - // Load client certificate for invalid token test - cert, err := tls.LoadX509KeyPair( - filepath.Join(certsDir, "agent.crt"), - filepath.Join(certsDir, "agent.key"), - ) - if err != nil { - t.Fatal(err) - } - - // Load CA cert - caCert, err := os.ReadFile(filepath.Join(certsDir, "ca.crt")) - if err != nil { - t.Fatal(err) - } - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - t.Fatal("Failed to append CA cert") - } - - // Create client with valid certificates but invalid token - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: caCertPool, - }, - }, - } - - reqBody := []byte(`{"agent_id": "test-invalid", "data": "test"}`) - req, err := http.NewRequest("POST", - "https://localhost:44443/secure", - bytes.NewBuffer(reqBody)) - if err != nil { - t.Fatal(err) - } - - req.Header.Set("Authorization", "invalid-token") - req.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusUnauthorized { - t.Errorf("Expected status unauthorized, got %v", resp.Status) - } - }) - - // End of test -} - -// Helper function to check for "use of closed network connection" error -func isClosedConnError(err error) bool { - if err == nil { - return false - } - return strings.Contains(err.Error(), "use of closed network connection") -} +//go:build linux + +package auth + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" + serverLib "github.com/sonroyaalmerol/pbs-plus/internal/auth/server" + "github.com/sonroyaalmerol/pbs-plus/internal/auth/testhelpers" +) + +func TestEndToEnd(t *testing.T) { + // Create temporary directory for test certificates + certsDir, err := os.MkdirTemp("", "auth-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(certsDir) + + // Generate certificates + certOpts := certificates.DefaultOptions() + certOpts.OutputDir = certsDir + certOpts.Organization = "Test Auth System" + certOpts.ValidDays = 1 + certOpts.Hostnames = []string{"localhost", 
"test.local"} + certOpts.IPs = []net.IP{net.ParseIP("127.0.0.1")} + + generator, err := certificates.NewGenerator(certOpts) + if err != nil { + t.Fatal(err) + } + + if err := generator.GenerateAll(); err != nil { + t.Fatal(err) + } + + // Start server + serverConfig := serverLib.DefaultConfig() + serverConfig.Address = ":44443" // Different port for testing + serverConfig.CertFile = filepath.Join(certsDir, "server.crt") + serverConfig.KeyFile = filepath.Join(certsDir, "server.key") + serverConfig.CAFile = filepath.Join(certsDir, "ca.crt") + serverConfig.TokenExpiration = 1 * time.Hour + + srv, err := testhelpers.NewServer(serverConfig) + if err != nil { + t.Fatal(err) + } + + // Start server in goroutine + serverErrCh := make(chan error, 1) + go func() { + if err := srv.Start(); err != nil { + if !isClosedConnError(err) { + serverErrCh <- err + } + } + }() + + // Give server time to start + time.Sleep(100 * time.Millisecond) + + defer func() { + // Graceful shutdown + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := srv.Stop(shutdownCtx); err != nil && !isClosedConnError(err) { + t.Error("shutdown error:", err) + } + + // Check if there were any server errors + select { + case err := <-serverErrCh: + if err != nil { + if !strings.Contains(err.Error(), "Server closed") { + t.Error("server error:", err) + } + } + default: + } + }() + + // Create and start multiple agents + agents := make([]*testhelpers.Agent, 3) + for i := 0; i < 3; i++ { + agentConfig := testhelpers.DefaultAgentConfig() + agentConfig.AgentID = fmt.Sprintf("test-agent-%d", i) + agentConfig.ServerURL = "https://localhost:44443" + agentConfig.CertFile = filepath.Join(certsDir, "agent.crt") + agentConfig.KeyFile = filepath.Join(certsDir, "agent.key") + agentConfig.CAFile = filepath.Join(certsDir, "ca.crt") + agentConfig.Timeout = 5 * time.Second + agentConfig.MaxRetries = 3 + agentConfig.RetryInterval = 100 * time.Millisecond + agentConfig.KeepAlive = true + agentConfig.KeepAliveInterval = 500 * time.Millisecond + + a, err := testhelpers.NewAgent(agentConfig) + if err != nil { + t.Fatal(err) + } + agents[i] = a + } + + // Start all agents + ctx := context.Background() + for i, a := range agents { + if err := a.Start(ctx); err != nil { + t.Fatalf("Failed to start agent %d: %v", i, err) + } + defer a.Stop() + } + + // Test parallel requests from all agents + t.Run("ParallelRequests", func(t *testing.T) { + var wg sync.WaitGroup + errors := make(chan error, len(agents)*3) // 3 requests per agent + + for _, a := range agents { + wg.Add(1) + go func(agent *testhelpers.Agent) { + defer wg.Done() + for i := 0; i < 3; i++ { + resp, err := agent.SendRequest(ctx, fmt.Sprintf("test message %d", i)) + if err != nil { + errors <- err + return + } + if resp == nil || resp.Message == "" { + errors <- fmt.Errorf("empty response") + return + } + time.Sleep(100 * time.Millisecond) + } + }(a) + } + + wg.Wait() + close(errors) + + for err := range errors { + if err != nil { + t.Error(err) + } + } + }) + + // Test token expiration and renewal + t.Run("TokenRenewal", func(t *testing.T) { + // Get initial token + initialToken := agents[0].GetToken() + if initialToken == "" { + t.Fatal("Expected non-empty initial token") + } + + // Wait for a keepalive cycle + time.Sleep(600 * time.Millisecond) + + // Send another request + resp, err := agents[0].SendRequest(ctx, "test token renewal") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("Expected non-nil response") + } + }) + + 
// Test invalid requests + t.Run("InvalidRequests", func(t *testing.T) { + // Load client certificate for invalid token test + cert, err := tls.LoadX509KeyPair( + filepath.Join(certsDir, "agent.crt"), + filepath.Join(certsDir, "agent.key"), + ) + if err != nil { + t.Fatal(err) + } + + // Load CA cert + caCert, err := os.ReadFile(filepath.Join(certsDir, "ca.crt")) + if err != nil { + t.Fatal(err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + t.Fatal("Failed to append CA cert") + } + + // Create client with valid certificates but invalid token + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + }, + }, + } + + reqBody := []byte(`{"agent_id": "test-invalid", "data": "test"}`) + req, err := http.NewRequest("POST", + "https://localhost:44443/secure", + bytes.NewBuffer(reqBody)) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("Authorization", "invalid-token") + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("Expected status unauthorized, got %v", resp.Status) + } + }) + + // End of test +} + +// Helper function to check for "use of closed network connection" error +func isClosedConnError(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "use of closed network connection") +} diff --git a/internal/auth/certificates/generator.go b/internal/auth/certificates/generator.go index 7cf8f53..929af08 100644 --- a/internal/auth/certificates/generator.go +++ b/internal/auth/certificates/generator.go @@ -1,444 +1,444 @@ -package certificates - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "errors" - "fmt" - "math" - "math/big" - "net" - "os" - "path/filepath" - "time" - - authErrors "github.com/sonroyaalmerol/pbs-plus/internal/auth/errors" -) - -// Options represents configuration for certificate generation -type Options struct { - // Organization name for the CA certificate - Organization string - // Common name for the CA certificate - CommonName string - // Valid duration for certificates - ValidDays int - // Key size in bits (e.g., 2048, 4096) - KeySize int - // Output directory for certificates - OutputDir string - // Hostnames to include in SAN - Hostnames []string - // IP addresses to include in SAN - IPs []net.IP -} - -// DefaultOptions returns default certificate generation options -func DefaultOptions() *Options { - // Get all non-loopback interfaces - interfaces, err := net.Interfaces() - hostnames := []string{"localhost"} - ips := []net.IP{net.ParseIP("127.0.0.1")} - - if err == nil { - for _, i := range interfaces { - // Skip loopback - if i.Flags&net.FlagLoopback != 0 { - continue - } - addrs, err := i.Addrs() - if err != nil { - continue - } - for _, addr := range addrs { - switch v := addr.(type) { - case *net.IPNet: - if ip4 := v.IP.To4(); ip4 != nil { - ips = append(ips, ip4) - } - } - } - } - } - - // Try to get hostname - if hostname, err := os.Hostname(); err == nil { - hostnames = append(hostnames, hostname) - } - - return &Options{ - Organization: "PBS Plus", - CommonName: "PBS Plus CA", - ValidDays: 365, - KeySize: 2048, - OutputDir: "/etc/proxmox-backup/pbs-plus/certs", - Hostnames: hostnames, - IPs: ips, - } -} - -// Generator handles certificate generation -type Generator struct { 
- options *Options - ca *x509.Certificate - caKey *rsa.PrivateKey -} - -// NewGenerator creates a new certificate generator -func NewGenerator(options *Options) (*Generator, error) { - if options == nil { - options = DefaultOptions() - } - - // Create output directory if it doesn't exist - if err := os.MkdirAll(options.OutputDir, 0755); err != nil { - return nil, authErrors.WrapError("create_output_dir", err) - } - - return &Generator{ - options: options, - }, nil -} - -func (g *Generator) GetCAPEM() []byte { - return EncodeCertPEM(g.ca.Raw) -} - -// GenerateCA generates a new CA certificate and private key -func (g *Generator) GenerateCA() error { - key, err := rsa.GenerateKey(rand.Reader, g.options.KeySize) - if err != nil { - return authErrors.WrapError("generate_ca_key", err) - } - - ca := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{g.options.Organization}, - CommonName: g.options.CommonName, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, g.options.ValidDays), - IsCA: true, - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, - BasicConstraintsValid: true, - } - - // Self-sign the CA certificate - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &key.PublicKey, key) - if err != nil { - return authErrors.WrapError("create_ca_cert", err) - } - - // Save CA certificate - if err := g.saveCertificate("ca.crt", caBytes); err != nil { - return err - } - - // Save CA private key - if err := g.savePrivateKey("ca.key", key); err != nil { - return err - } - - g.ca = ca - g.caKey = key - return nil -} - -// GenerateCert generates a new certificate signed by the CA -func (g *Generator) GenerateCert(name string) error { - if g.ca == nil || g.caKey == nil { - return authErrors.WrapError("generate_cert", - errors.New("CA must be generated first")) - } - - key, err := rsa.GenerateKey(rand.Reader, g.options.KeySize) - if err != nil { - return authErrors.WrapError("generate_key", err) - } - - template := &x509.Certificate{ - SerialNumber: big.NewInt(time.Now().UnixNano()), - Subject: pkix.Name{ - Organization: []string{g.options.Organization}, - CommonName: name, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, g.options.ValidDays), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - DNSNames: g.options.Hostnames, - IPAddresses: g.options.IPs, - } - - // Sign the certificate with CA - certBytes, err := x509.CreateCertificate(rand.Reader, template, g.ca, &key.PublicKey, g.caKey) - if err != nil { - return authErrors.WrapError("create_cert", err) - } - - // Save certificate - if err := g.saveCertificate(name+".crt", certBytes); err != nil { - return err - } - - // Save private key - if err := g.savePrivateKey(name+".key", key); err != nil { - return err - } - - return nil -} - -// GenerateAll generates CA and all required certificates -func (g *Generator) GenerateAll() error { - if err := g.GenerateCA(); err != nil { - return err - } - - if err := g.GenerateCert("server"); err != nil { - return err - } - - if err := g.GenerateCert("agent"); err != nil { - return err - } - - return nil -} - -func (g *Generator) saveCertificate(filename string, certBytes []byte) error { - filePath := filepath.Join(g.options.OutputDir, filename) - - certOut, err := os.Create(filePath) - if err != nil { - return authErrors.WrapError("create_cert_file", err) - } - defer certOut.Close() - - if err := 
pem.Encode(certOut, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }); err != nil { - return authErrors.WrapError("encode_cert", err) - } - - return nil -} - -func (g *Generator) savePrivateKey(filename string, key *rsa.PrivateKey) error { - filePath := filepath.Join(g.options.OutputDir, filename) - keyOut, err := os.OpenFile( - filePath, - os.O_WRONLY|os.O_CREATE|os.O_TRUNC, - 0640, - ) - if err != nil { - return authErrors.WrapError("create_key_file", err) - } - defer keyOut.Close() - - if err := pem.Encode(keyOut, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }); err != nil { - return authErrors.WrapError("encode_key", err) - } - - return nil -} - -func (g *Generator) ValidateExistingCerts() error { - serverCertPath := filepath.Join(g.options.OutputDir, "server.crt") - caPath := filepath.Join(g.options.OutputDir, "ca.crt") - caKeyPath := filepath.Join(g.options.OutputDir, "ca.key") - - // Check if files exist - if _, err := os.Stat(serverCertPath); os.IsNotExist(err) { - return fmt.Errorf("server certificate not found: %s", serverCertPath) - } - if _, err := os.Stat(caPath); os.IsNotExist(err) { - return fmt.Errorf("CA certificate not found: %s", caPath) - } - if _, err := os.Stat(caKeyPath); os.IsNotExist(err) { - return fmt.Errorf("CA certificate not found: %s", caPath) - } - - // Load server certificate - serverCertPEM, err := os.ReadFile(serverCertPath) - if err != nil { - return fmt.Errorf("failed to read server certificate: %w", err) - } - block, _ := pem.Decode(serverCertPEM) - if block == nil { - return fmt.Errorf("failed to parse server certificate PEM") - } - serverCert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return fmt.Errorf("failed to parse server certificate: %w", err) - } - - // Load CA certificate - caPEM, err := os.ReadFile(caPath) - if err != nil { - return fmt.Errorf("failed to read CA certificate: %w", err) - } - block, _ = pem.Decode(caPEM) - if block == nil { - return fmt.Errorf("failed to parse CA certificate PEM") - } - caCert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return fmt.Errorf("failed to parse CA certificate: %w", err) - } - - g.ca = caCert - - caKeyPEM, err := os.ReadFile(caKeyPath) - if err != nil { - return fmt.Errorf("failed to read CA key: %w", err) - } - block, _ = pem.Decode(caKeyPEM) - if block == nil { - return fmt.Errorf("failed to parse CA key PEM") - } - - caKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return fmt.Errorf("failed to parse CA key: %w", err) - } - - g.caKey = caKey - - // Verify server certificate is signed by CA - roots := x509.NewCertPool() - roots.AddCert(caCert) - opts := x509.VerifyOptions{ - Roots: roots, - KeyUsages: []x509.ExtKeyUsage{ - x509.ExtKeyUsageServerAuth, - }, - } - - if _, err := serverCert.Verify(opts); err != nil { - return fmt.Errorf("server certificate validation failed: %w", err) - } - - // Check certificate expiry - now := time.Now() - if now.Before(serverCert.NotBefore) { - return fmt.Errorf("server certificate is not yet valid") - } - if now.After(serverCert.NotAfter) { - return fmt.Errorf("server certificate has expired") - } - if now.Before(caCert.NotBefore) { - return fmt.Errorf("CA certificate is not yet valid") - } - if now.After(caCert.NotAfter) { - return fmt.Errorf("CA certificate has expired") - } - - return nil -} - -func GenerateCSR(commonName string, keySize int) ([]byte, *rsa.PrivateKey, error) { - privKey, err := rsa.GenerateKey(rand.Reader, keySize) - if err != nil { - 
return nil, nil, fmt.Errorf("failed to generate private key: %w", err) - } - - template := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: commonName, - }, - SignatureAlgorithm: x509.SHA256WithRSA, - } - - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) - if err != nil { - return nil, nil, fmt.Errorf("failed to create CSR: %w", err) - } - - return csrBytes, privKey, nil -} - -func (g *Generator) SignCSR(csr []byte) ([]byte, error) { - if g == nil { - return nil, fmt.Errorf("generator is nil") - } - if g.caKey == nil { - return nil, fmt.Errorf("CA private key is nil") - } - if g.ca == nil { - return nil, fmt.Errorf("CA certificate is nil") - } - if g.options == nil { - return nil, fmt.Errorf("options are nil") - } - - validDays := g.options.ValidDays - if validDays <= 0 { - return nil, fmt.Errorf("invalid validity period: %d days", validDays) - } - - // Parse CSR - csrObj, err := x509.ParseCertificateRequest(csr) - if err != nil { - return nil, fmt.Errorf("failed to parse CSR: %w", err) - } - if err := csrObj.CheckSignature(); err != nil { - return nil, fmt.Errorf("CSR signature check failed: %w", err) - } - - // Validate public key - if csrObj.PublicKey == nil { - return nil, fmt.Errorf("CSR public key is nil") - } - - serialNumber, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) - if err != nil { - return nil, fmt.Errorf("failed to generate serial number: %w", err) - } - - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: csrObj.Subject, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 0, validDays), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, template, g.ca, csrObj.PublicKey, g.caKey) - if err != nil { - return nil, fmt.Errorf("failed to create certificate: %w", err) - } - - return EncodeCertPEM(certBytes), nil -} - -// Helper functions to encode to PEM -func EncodeCertPEM(cert []byte) []byte { - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cert, - }) -} - -func EncodeKeyPEM(key *rsa.PrivateKey) []byte { - return pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) -} - -func EncodeCSRPEM(csr []byte) []byte { - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csr, - }) -} +package certificates + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math" + "math/big" + "net" + "os" + "path/filepath" + "time" + + authErrors "github.com/sonroyaalmerol/pbs-plus/internal/auth/errors" +) + +// Options represents configuration for certificate generation +type Options struct { + // Organization name for the CA certificate + Organization string + // Common name for the CA certificate + CommonName string + // Valid duration for certificates + ValidDays int + // Key size in bits (e.g., 2048, 4096) + KeySize int + // Output directory for certificates + OutputDir string + // Hostnames to include in SAN + Hostnames []string + // IP addresses to include in SAN + IPs []net.IP +} + +// DefaultOptions returns default certificate generation options +func DefaultOptions() *Options { + // Get all non-loopback interfaces + interfaces, err := net.Interfaces() + hostnames := []string{"localhost"} + ips := []net.IP{net.ParseIP("127.0.0.1")} + + if err == nil { + for _, i := range interfaces { + // Skip 
loopback + if i.Flags&net.FlagLoopback != 0 { + continue + } + addrs, err := i.Addrs() + if err != nil { + continue + } + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + if ip4 := v.IP.To4(); ip4 != nil { + ips = append(ips, ip4) + } + } + } + } + } + + // Try to get hostname + if hostname, err := os.Hostname(); err == nil { + hostnames = append(hostnames, hostname) + } + + return &Options{ + Organization: "PBS Plus", + CommonName: "PBS Plus CA", + ValidDays: 365, + KeySize: 2048, + OutputDir: "/etc/proxmox-backup/pbs-plus/certs", + Hostnames: hostnames, + IPs: ips, + } +} + +// Generator handles certificate generation +type Generator struct { + options *Options + ca *x509.Certificate + caKey *rsa.PrivateKey +} + +// NewGenerator creates a new certificate generator +func NewGenerator(options *Options) (*Generator, error) { + if options == nil { + options = DefaultOptions() + } + + // Create output directory if it doesn't exist + if err := os.MkdirAll(options.OutputDir, 0755); err != nil { + return nil, authErrors.WrapError("create_output_dir", err) + } + + return &Generator{ + options: options, + }, nil +} + +func (g *Generator) GetCAPEM() []byte { + return EncodeCertPEM(g.ca.Raw) +} + +// GenerateCA generates a new CA certificate and private key +func (g *Generator) GenerateCA() error { + key, err := rsa.GenerateKey(rand.Reader, g.options.KeySize) + if err != nil { + return authErrors.WrapError("generate_ca_key", err) + } + + ca := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{g.options.Organization}, + CommonName: g.options.CommonName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, g.options.ValidDays), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + + // Self-sign the CA certificate + caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &key.PublicKey, key) + if err != nil { + return authErrors.WrapError("create_ca_cert", err) + } + + // Save CA certificate + if err := g.saveCertificate("ca.crt", caBytes); err != nil { + return err + } + + // Save CA private key + if err := g.savePrivateKey("ca.key", key); err != nil { + return err + } + + g.ca = ca + g.caKey = key + return nil +} + +// GenerateCert generates a new certificate signed by the CA +func (g *Generator) GenerateCert(name string) error { + if g.ca == nil || g.caKey == nil { + return authErrors.WrapError("generate_cert", + errors.New("CA must be generated first")) + } + + key, err := rsa.GenerateKey(rand.Reader, g.options.KeySize) + if err != nil { + return authErrors.WrapError("generate_key", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + Organization: []string{g.options.Organization}, + CommonName: name, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, g.options.ValidDays), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + DNSNames: g.options.Hostnames, + IPAddresses: g.options.IPs, + } + + // Sign the certificate with CA + certBytes, err := x509.CreateCertificate(rand.Reader, template, g.ca, &key.PublicKey, g.caKey) + if err != nil { + return authErrors.WrapError("create_cert", err) + } + + // Save certificate + if err := g.saveCertificate(name+".crt", certBytes); err != nil { + return err + } + + // Save private key + if err := 
g.savePrivateKey(name+".key", key); err != nil { + return err + } + + return nil +} + +// GenerateAll generates CA and all required certificates +func (g *Generator) GenerateAll() error { + if err := g.GenerateCA(); err != nil { + return err + } + + if err := g.GenerateCert("server"); err != nil { + return err + } + + if err := g.GenerateCert("agent"); err != nil { + return err + } + + return nil +} + +func (g *Generator) saveCertificate(filename string, certBytes []byte) error { + filePath := filepath.Join(g.options.OutputDir, filename) + + certOut, err := os.Create(filePath) + if err != nil { + return authErrors.WrapError("create_cert_file", err) + } + defer certOut.Close() + + if err := pem.Encode(certOut, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }); err != nil { + return authErrors.WrapError("encode_cert", err) + } + + return nil +} + +func (g *Generator) savePrivateKey(filename string, key *rsa.PrivateKey) error { + filePath := filepath.Join(g.options.OutputDir, filename) + keyOut, err := os.OpenFile( + filePath, + os.O_WRONLY|os.O_CREATE|os.O_TRUNC, + 0640, + ) + if err != nil { + return authErrors.WrapError("create_key_file", err) + } + defer keyOut.Close() + + if err := pem.Encode(keyOut, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }); err != nil { + return authErrors.WrapError("encode_key", err) + } + + return nil +} + +func (g *Generator) ValidateExistingCerts() error { + serverCertPath := filepath.Join(g.options.OutputDir, "server.crt") + caPath := filepath.Join(g.options.OutputDir, "ca.crt") + caKeyPath := filepath.Join(g.options.OutputDir, "ca.key") + + // Check if files exist + if _, err := os.Stat(serverCertPath); os.IsNotExist(err) { + return fmt.Errorf("server certificate not found: %s", serverCertPath) + } + if _, err := os.Stat(caPath); os.IsNotExist(err) { + return fmt.Errorf("CA certificate not found: %s", caPath) + } + if _, err := os.Stat(caKeyPath); os.IsNotExist(err) { + return fmt.Errorf("CA key not found: %s", caKeyPath) + } + + // Load server certificate + serverCertPEM, err := os.ReadFile(serverCertPath) + if err != nil { + return fmt.Errorf("failed to read server certificate: %w", err) + } + block, _ := pem.Decode(serverCertPEM) + if block == nil { + return fmt.Errorf("failed to parse server certificate PEM") + } + serverCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse server certificate: %w", err) + } + + // Load CA certificate + caPEM, err := os.ReadFile(caPath) + if err != nil { + return fmt.Errorf("failed to read CA certificate: %w", err) + } + block, _ = pem.Decode(caPEM) + if block == nil { + return fmt.Errorf("failed to parse CA certificate PEM") + } + caCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CA certificate: %w", err) + } + + g.ca = caCert + + caKeyPEM, err := os.ReadFile(caKeyPath) + if err != nil { + return fmt.Errorf("failed to read CA key: %w", err) + } + block, _ = pem.Decode(caKeyPEM) + if block == nil { + return fmt.Errorf("failed to parse CA key PEM") + } + + caKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse CA key: %w", err) + } + + g.caKey = caKey + + // Verify server certificate is signed by CA + roots := x509.NewCertPool() + roots.AddCert(caCert) + opts := x509.VerifyOptions{ + Roots: roots, + KeyUsages: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + } + + if _, err := serverCert.Verify(opts); err 
!= nil { + return fmt.Errorf("server certificate validation failed: %w", err) + } + + // Check certificate expiry + now := time.Now() + if now.Before(serverCert.NotBefore) { + return fmt.Errorf("server certificate is not yet valid") + } + if now.After(serverCert.NotAfter) { + return fmt.Errorf("server certificate has expired") + } + if now.Before(caCert.NotBefore) { + return fmt.Errorf("CA certificate is not yet valid") + } + if now.After(caCert.NotAfter) { + return fmt.Errorf("CA certificate has expired") + } + + return nil +} + +func GenerateCSR(commonName string, keySize int) ([]byte, *rsa.PrivateKey, error) { + privKey, err := rsa.GenerateKey(rand.Reader, keySize) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: commonName, + }, + SignatureAlgorithm: x509.SHA256WithRSA, + } + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create CSR: %w", err) + } + + return csrBytes, privKey, nil +} + +func (g *Generator) SignCSR(csr []byte) ([]byte, error) { + if g == nil { + return nil, fmt.Errorf("generator is nil") + } + if g.caKey == nil { + return nil, fmt.Errorf("CA private key is nil") + } + if g.ca == nil { + return nil, fmt.Errorf("CA certificate is nil") + } + if g.options == nil { + return nil, fmt.Errorf("options are nil") + } + + validDays := g.options.ValidDays + if validDays <= 0 { + return nil, fmt.Errorf("invalid validity period: %d days", validDays) + } + + // Parse CSR + csrObj, err := x509.ParseCertificateRequest(csr) + if err != nil { + return nil, fmt.Errorf("failed to parse CSR: %w", err) + } + if err := csrObj.CheckSignature(); err != nil { + return nil, fmt.Errorf("CSR signature check failed: %w", err) + } + + // Validate public key + if csrObj.PublicKey == nil { + return nil, fmt.Errorf("CSR public key is nil") + } + + serialNumber, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: csrObj.Subject, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, validDays), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, template, g.ca, csrObj.PublicKey, g.caKey) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %w", err) + } + + return EncodeCertPEM(certBytes), nil +} + +// Helper functions to encode to PEM +func EncodeCertPEM(cert []byte) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + }) +} + +func EncodeKeyPEM(key *rsa.PrivateKey) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) +} + +func EncodeCSRPEM(csr []byte) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csr, + }) +} diff --git a/internal/auth/errors/errors.go b/internal/auth/errors/errors.go index 1fcfe64..0a5a8dc 100644 --- a/internal/auth/errors/errors.go +++ b/internal/auth/errors/errors.go @@ -1,48 +1,48 @@ -package errors - -import ( - "errors" - "fmt" -) - -var ( - // ErrInvalidToken indicates the provided token is invalid or expired - ErrInvalidToken = errors.New("invalid or expired token") - - // 
ErrUnauthorized indicates the client is not authorized - ErrUnauthorized = errors.New("unauthorized") - - // ErrCertificateRequired indicates missing or invalid certificates - ErrCertificateRequired = errors.New("valid certificates are required") - - // ErrInvalidConfig indicates invalid configuration - ErrInvalidConfig = errors.New("invalid configuration") -) - -// AuthError represents a custom error type for authentication errors -type AuthError struct { - Op string // Operation that failed - Err error // Underlying error -} - -func (e *AuthError) Error() string { - if e.Op != "" { - return fmt.Sprintf("%s: %v", e.Op, e.Err) - } - return e.Err.Error() -} - -func (e *AuthError) Unwrap() error { - return e.Err -} - -// WrapError wraps an error with additional operation context -func WrapError(op string, err error) error { - if err == nil { - return nil - } - return &AuthError{ - Op: op, - Err: err, - } -} +package errors + +import ( + "errors" + "fmt" +) + +var ( + // ErrInvalidToken indicates the provided token is invalid or expired + ErrInvalidToken = errors.New("invalid or expired token") + + // ErrUnauthorized indicates the client is not authorized + ErrUnauthorized = errors.New("unauthorized") + + // ErrCertificateRequired indicates missing or invalid certificates + ErrCertificateRequired = errors.New("valid certificates are required") + + // ErrInvalidConfig indicates invalid configuration + ErrInvalidConfig = errors.New("invalid configuration") +) + +// AuthError represents a custom error type for authentication errors +type AuthError struct { + Op string // Operation that failed + Err error // Underlying error +} + +func (e *AuthError) Error() string { + if e.Op != "" { + return fmt.Sprintf("%s: %v", e.Op, e.Err) + } + return e.Err.Error() +} + +func (e *AuthError) Unwrap() error { + return e.Err +} + +// WrapError wraps an error with additional operation context +func WrapError(op string, err error) error { + if err == nil { + return nil + } + return &AuthError{ + Op: op, + Err: err, + } +} diff --git a/internal/auth/server/proxy_mount.go b/internal/auth/server/proxy_mount.go index 5702dd4..4763b4e 100644 --- a/internal/auth/server/proxy_mount.go +++ b/internal/auth/server/proxy_mount.go @@ -1,80 +1,80 @@ -//go:build linux - -package server - -import ( - "fmt" - "os" - "path/filepath" - "syscall" - - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -const proxyCert = "/etc/proxmox-backup/proxy.pem" -const proxyKey = "/etc/proxmox-backup/proxy.key" - -func (c *Config) Mount() error { - - // Check if something is already mounted at the target path - if utils.IsMounted(proxyCert) { - if err := syscall.Unmount(proxyCert, 0); err != nil { - return fmt.Errorf("failed to unmount existing file: %w", err) - } - } - if utils.IsMounted(proxyKey) { - if err := syscall.Unmount(proxyKey, 0); err != nil { - return fmt.Errorf("failed to unmount existing file: %w", err) - } - } - - // Create backup directory if it doesn't exist - backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") - if err := os.MkdirAll(backupDir, 0755); err != nil { - return fmt.Errorf("failed to create backup directory: %w", err) - } - - // Create backup filename with timestamp - backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(proxyCert))) - backupKeyPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(proxyKey))) - - // Read existing file - original, err := os.ReadFile(proxyCert) - if err != nil { - return fmt.Errorf("failed to read original file: %w", err) - } 
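
Taken together, the certificate code earlier in this diff covers the whole lifecycle: generate a CA once, issue server and agent certificates from it, and re-validate everything on the next startup. A minimal usage sketch follows; it assumes the exported names live in a `certificates` package under this repository's internal auth tree (the exact import path and package name are assumptions), so treat it as illustrative rather than a drop-in main.

package main

import (
	"log"

	certificates "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" // assumed path
)

func main() {
	// nil options fall back to DefaultOptions (hostnames, local IPs, 2048-bit keys).
	gen, err := certificates.NewGenerator(nil)
	if err != nil {
		log.Fatalf("init generator: %v", err)
	}

	// Reuse certificates that still verify against the CA; regenerate otherwise.
	if err := gen.ValidateExistingCerts(); err != nil {
		log.Printf("existing certs unusable (%v); regenerating", err)
		if err := gen.GenerateAll(); err != nil {
			log.Fatalf("generate CA and certs: %v", err)
		}
	}

	// Either path leaves the CA loaded, so its PEM can be handed to agents.
	log.Printf("CA PEM is %d bytes", len(gen.GetCAPEM()))
}
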
- originalKey, err := os.ReadFile(proxyKey) - if err != nil { - return fmt.Errorf("failed to read original file: %w", err) - } - - // Create backup - if err := os.WriteFile(backupPath, original, 0644); err != nil { - return fmt.Errorf("failed to create backup: %w", err) - } - if err := os.WriteFile(backupKeyPath, originalKey, 0644); err != nil { - return fmt.Errorf("failed to create backup: %w", err) - } - - // Perform bind mount - if err := syscall.Mount(c.CertFile, proxyCert, "", syscall.MS_BIND, ""); err != nil { - return fmt.Errorf("failed to mount file: %w", err) - } - if err := syscall.Mount(c.KeyFile, proxyKey, "", syscall.MS_BIND, ""); err != nil { - return fmt.Errorf("failed to mount file: %w", err) - } - - return nil -} - -func (c *Config) Unmount() error { - // Unmount the file - if err := syscall.Unmount(proxyCert, 0); err != nil { - return fmt.Errorf("failed to unmount file: %w", err) - } - if err := syscall.Unmount(proxyKey, 0); err != nil { - return fmt.Errorf("failed to unmount file: %w", err) - } - - return nil -} +//go:build linux + +package server + +import ( + "fmt" + "os" + "path/filepath" + "syscall" + + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +const proxyCert = "/etc/proxmox-backup/proxy.pem" +const proxyKey = "/etc/proxmox-backup/proxy.key" + +func (c *Config) Mount() error { + + // Check if something is already mounted at the target path + if utils.IsMounted(proxyCert) { + if err := syscall.Unmount(proxyCert, 0); err != nil { + return fmt.Errorf("failed to unmount existing file: %w", err) + } + } + if utils.IsMounted(proxyKey) { + if err := syscall.Unmount(proxyKey, 0); err != nil { + return fmt.Errorf("failed to unmount existing file: %w", err) + } + } + + // Create backup directory if it doesn't exist + backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Create backup filename with timestamp + backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(proxyCert))) + backupKeyPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(proxyKey))) + + // Read existing file + original, err := os.ReadFile(proxyCert) + if err != nil { + return fmt.Errorf("failed to read original file: %w", err) + } + originalKey, err := os.ReadFile(proxyKey) + if err != nil { + return fmt.Errorf("failed to read original file: %w", err) + } + + // Create backup + if err := os.WriteFile(backupPath, original, 0644); err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + if err := os.WriteFile(backupKeyPath, originalKey, 0644); err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + + // Perform bind mount + if err := syscall.Mount(c.CertFile, proxyCert, "", syscall.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to mount file: %w", err) + } + if err := syscall.Mount(c.KeyFile, proxyKey, "", syscall.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to mount file: %w", err) + } + + return nil +} + +func (c *Config) Unmount() error { + // Unmount the file + if err := syscall.Unmount(proxyCert, 0); err != nil { + return fmt.Errorf("failed to unmount file: %w", err) + } + if err := syscall.Unmount(proxyKey, 0); err != nil { + return fmt.Errorf("failed to unmount file: %w", err) + } + + return nil +} diff --git a/internal/auth/server/server.go b/internal/auth/server/server.go index 00b100a..da2ae09 100644 --- a/internal/auth/server/server.go +++ 
b/internal/auth/server/server.go @@ -1,93 +1,93 @@ -package server - -import ( - "crypto/tls" - "crypto/x509" - "errors" - "os" - "time" - - authErrors "github.com/sonroyaalmerol/pbs-plus/internal/auth/errors" -) - -// Config represents the server configuration -type Config struct { - // Server TLS configuration - CertFile string - KeyFile string - CAFile string - CAKey string - - // Token configuration - TokenExpiration time.Duration - TokenSecret string - - // Server configuration - Address string - ReadTimeout time.Duration - WriteTimeout time.Duration - MaxHeaderBytes int - - // Rate limiting - RateLimit float64 // Requests per second - RateBurst int // Maximum burst size -} - -// DefaultConfig returns a default server configuration -func DefaultConfig() *Config { - return &Config{ - Address: ":8008", - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - MaxHeaderBytes: 1 << 20, // 1MB - - TokenExpiration: 24 * time.Hour, - - RateLimit: 100.0, - RateBurst: 200, - } -} - -// Validate checks if the configuration is valid -func (c *Config) Validate() error { - if c.CertFile == "" || c.KeyFile == "" || c.CAFile == "" { - return authErrors.ErrCertificateRequired - } - - // Check if certificate files exist - files := []string{c.CertFile, c.KeyFile, c.CAFile} - for _, file := range files { - if _, err := os.Stat(file); err != nil { - return authErrors.WrapError("validate_config", err) - } - } - - return nil -} - -// LoadTLSConfig creates a TLS configuration from the server config -func (c *Config) LoadTLSConfig() (*tls.Config, error) { - // Load server certificate - cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) - if err != nil { - return nil, authErrors.WrapError("load_tls_config", err) - } - - // Load CA cert - caCert, err := os.ReadFile(c.CAFile) - if err != nil { - return nil, authErrors.WrapError("load_tls_config", err) - } - - caCertPool := x509.NewCertPool() - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, authErrors.WrapError("load_tls_config", - errors.New("failed to append CA certificate")) - } - - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - ClientCAs: caCertPool, - ClientAuth: tls.VerifyClientCertIfGiven, - }, nil -} +package server + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "os" + "time" + + authErrors "github.com/sonroyaalmerol/pbs-plus/internal/auth/errors" +) + +// Config represents the server configuration +type Config struct { + // Server TLS configuration + CertFile string + KeyFile string + CAFile string + CAKey string + + // Token configuration + TokenExpiration time.Duration + TokenSecret string + + // Server configuration + Address string + ReadTimeout time.Duration + WriteTimeout time.Duration + MaxHeaderBytes int + + // Rate limiting + RateLimit float64 // Requests per second + RateBurst int // Maximum burst size +} + +// DefaultConfig returns a default server configuration +func DefaultConfig() *Config { + return &Config{ + Address: ":8008", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + MaxHeaderBytes: 1 << 20, // 1MB + + TokenExpiration: 24 * time.Hour, + + RateLimit: 100.0, + RateBurst: 200, + } +} + +// Validate checks if the configuration is valid +func (c *Config) Validate() error { + if c.CertFile == "" || c.KeyFile == "" || c.CAFile == "" { + return authErrors.ErrCertificateRequired + } + + // Check if certificate files exist + files := []string{c.CertFile, c.KeyFile, c.CAFile} + for _, file := range files { + if _, err := os.Stat(file); err != nil { + return 
authErrors.WrapError("validate_config", err) + } + } + + return nil +} + +// LoadTLSConfig creates a TLS configuration from the server config +func (c *Config) LoadTLSConfig() (*tls.Config, error) { + // Load server certificate + cert, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile) + if err != nil { + return nil, authErrors.WrapError("load_tls_config", err) + } + + // Load CA cert + caCert, err := os.ReadFile(c.CAFile) + if err != nil { + return nil, authErrors.WrapError("load_tls_config", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, authErrors.WrapError("load_tls_config", + errors.New("failed to append CA certificate")) + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientCAs: caCertPool, + ClientAuth: tls.VerifyClientCertIfGiven, + }, nil +} diff --git a/internal/auth/token/manager.go b/internal/auth/token/manager.go index 3ca615b..96fa864 100644 --- a/internal/auth/token/manager.go +++ b/internal/auth/token/manager.go @@ -1,90 +1,90 @@ -package token - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "time" - - "github.com/golang-jwt/jwt" - authErrors "github.com/sonroyaalmerol/pbs-plus/internal/auth/errors" -) - -// Claims represents the JWT claims -type Claims struct { - jwt.StandardClaims -} - -// Manager handles token generation and validation -type Manager struct { - secret []byte - config Config -} - -// Config represents token manager configuration -type Config struct { - // TokenExpiration is the duration for which a token is valid - TokenExpiration time.Duration - // SecretKey is the key used to sign tokens - SecretKey string -} - -// NewManager creates a new token manager -func NewManager(config Config) (*Manager, error) { - if config.SecretKey == "" { - // Generate a random secret if none provided - secret := make([]byte, 32) - if _, err := rand.Read(secret); err != nil { - return nil, authErrors.WrapError("generate_secret", err) - } - config.SecretKey = base64.StdEncoding.EncodeToString(secret) - } - - if config.TokenExpiration == 0 { - config.TokenExpiration = 24 * time.Hour - } - - m := &Manager{ - secret: []byte(config.SecretKey), - config: config, - } - - return m, nil -} - -// GenerateToken creates a new JWT token for an agent -func (m *Manager) GenerateToken() (string, error) { - claims := Claims{ - StandardClaims: jwt.StandardClaims{ - ExpiresAt: time.Now().Add(m.config.TokenExpiration).Unix(), - IssuedAt: time.Now().Unix(), - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString(m.secret) - if err != nil { - return "", authErrors.WrapError("generate_token", err) - } - - return tokenString, nil -} - -// ValidateToken checks if a token is valid -func (m *Manager) ValidateToken(tokenString string) error { - token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, authErrors.WrapError("validate_token", fmt.Errorf("unexpected signing method: %v", token.Header["alg"])) - } - return m.secret, nil - }) - if err != nil { - return authErrors.WrapError("validate_token", err) - } - - if _, ok := token.Claims.(*Claims); ok && token.Valid { - return nil - } - - return authErrors.ErrInvalidToken -} +package token + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "time" + + "github.com/golang-jwt/jwt" + authErrors "github.com/sonroyaalmerol/pbs-plus/internal/auth/errors" +) + +// Claims represents the JWT 
claims +type Claims struct { + jwt.StandardClaims +} + +// Manager handles token generation and validation +type Manager struct { + secret []byte + config Config +} + +// Config represents token manager configuration +type Config struct { + // TokenExpiration is the duration for which a token is valid + TokenExpiration time.Duration + // SecretKey is the key used to sign tokens + SecretKey string +} + +// NewManager creates a new token manager +func NewManager(config Config) (*Manager, error) { + if config.SecretKey == "" { + // Generate a random secret if none provided + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + return nil, authErrors.WrapError("generate_secret", err) + } + config.SecretKey = base64.StdEncoding.EncodeToString(secret) + } + + if config.TokenExpiration == 0 { + config.TokenExpiration = 24 * time.Hour + } + + m := &Manager{ + secret: []byte(config.SecretKey), + config: config, + } + + return m, nil +} + +// GenerateToken creates a new JWT token for an agent +func (m *Manager) GenerateToken() (string, error) { + claims := Claims{ + StandardClaims: jwt.StandardClaims{ + ExpiresAt: time.Now().Add(m.config.TokenExpiration).Unix(), + IssuedAt: time.Now().Unix(), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString(m.secret) + if err != nil { + return "", authErrors.WrapError("generate_token", err) + } + + return tokenString, nil +} + +// ValidateToken checks if a token is valid +func (m *Manager) ValidateToken(tokenString string) error { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, authErrors.WrapError("validate_token", fmt.Errorf("unexpected signing method: %v", token.Header["alg"])) + } + return m.secret, nil + }) + if err != nil { + return authErrors.WrapError("validate_token", err) + } + + if _, ok := token.Claims.(*Claims); ok && token.Valid { + return nil + } + + return authErrors.ErrInvalidToken +} diff --git a/internal/backend/backup/command.go b/internal/backend/backup/command.go index 8b80a07..d6ebb87 100644 --- a/internal/backend/backup/command.go +++ b/internal/backend/backup/command.go @@ -1,128 +1,128 @@ -//go:build linux - -package backup - -import ( - "fmt" - "io" - "os" - "os/exec" - "strings" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" -) - -func prepareBackupCommand(job *types.Job, storeInstance *store.Store, srcPath string, isAgent bool) (*exec.Cmd, error) { - if srcPath == "" { - return nil, fmt.Errorf("RunBackup: source path is required") - } - - backupId, err := getBackupId(isAgent, job.Target) - if err != nil { - return nil, fmt.Errorf("RunBackup: failed to get backup ID: %w", err) - } - - jobStore := fmt.Sprintf("%s@localhost:%s", proxmox.Session.APIToken.TokenId, job.Store) - if jobStore == "@localhost:" { - return nil, fmt.Errorf("RunBackup: invalid job store configuration") - } - - cmdArgs := buildCommandArgs(storeInstance, job, srcPath, jobStore, backupId) - if len(cmdArgs) == 0 { - return nil, fmt.Errorf("RunBackup: failed to build command arguments") - } - - cmd := exec.Command("/usr/bin/proxmox-backup-client", cmdArgs...) 
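
The token manager shown above pairs with the certificate machinery earlier in this diff: certificates identify long-lived agents, while these HMAC-signed JWTs can gate initial enrollment. A short sketch of issuing and validating a token; the 12-hour expiration is an illustrative value, and leaving SecretKey empty makes NewManager derive a random 32-byte secret.

package main

import (
	"log"
	"time"

	"github.com/sonroyaalmerol/pbs-plus/internal/auth/token"
)

func main() {
	mgr, err := token.NewManager(token.Config{
		TokenExpiration: 12 * time.Hour, // illustrative value
	})
	if err != nil {
		log.Fatalf("create token manager: %v", err)
	}

	tok, err := mgr.GenerateToken()
	if err != nil {
		log.Fatalf("generate token: %v", err)
	}

	// Later, when a client presents the token back, ValidateToken re-checks
	// the signing method, the HMAC signature, and the expiry claim.
	if err := mgr.ValidateToken(tok); err != nil {
		log.Fatalf("token rejected: %v", err)
	}
	log.Println("token accepted")
}
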
- cmd.Env = buildCommandEnv(storeInstance) - - return cmd, nil -} - -func getBackupId(isAgent bool, targetName string) (string, error) { - if !isAgent { - hostname, err := os.Hostname() - if err != nil { - hostnameBytes, err := os.ReadFile("/etc/hostname") - if err != nil { - return "localhost", nil - } - return strings.TrimSpace(string(hostnameBytes)), nil - } - return hostname, nil - } - if targetName == "" { - return "", fmt.Errorf("target name is required for agent backup") - } - return strings.TrimSpace(strings.Split(targetName, " - ")[0]), nil -} - -func buildCommandArgs(storeInstance *store.Store, job *types.Job, srcPath string, jobStore string, backupId string) []string { - if srcPath == "" || jobStore == "" || backupId == "" { - return nil - } - - cmdArgs := []string{ - "backup", - fmt.Sprintf("%s.pxar:%s", strings.ReplaceAll(job.Target, " ", "-"), srcPath), - "--repository", jobStore, - "--change-detection-mode=metadata", - "--backup-id", backupId, - "--crypt-mode=none", - "--skip-e2big-xattr", "true", - "--skip-lost-and-found", "true", - } - - // Add exclusions - for _, exclusion := range job.Exclusions { - path := exclusion.Path - if !strings.HasPrefix(exclusion.Path, "/") && !strings.HasPrefix(exclusion.Path, "!") && !strings.HasPrefix(exclusion.Path, "**/") { - path = "**/" + path - } - - cmdArgs = append(cmdArgs, "--exclude", path) - } - - // Add namespace if specified - if job.Namespace != "" { - _ = CreateNamespace(job.Namespace, job, storeInstance) - cmdArgs = append(cmdArgs, "--ns", job.Namespace) - } - - return cmdArgs -} - -func buildCommandEnv(storeInstance *store.Store) []string { - if storeInstance == nil || proxmox.Session.APIToken == nil { - return os.Environ() - } - - env := append(os.Environ(), - fmt.Sprintf("PBS_PASSWORD=%s", proxmox.Session.APIToken.Value)) - // env = append(env, "PBS_LOG=debug") - - // Add fingerprint if available - if pbsStatus, err := proxmox.Session.GetPBSStatus(); err == nil { - if fingerprint, ok := pbsStatus.Info["fingerprint"]; ok { - env = append(env, fmt.Sprintf("PBS_FINGERPRINT=%s", fingerprint)) - } - } - - return env -} - -func setupCommandPipes(cmd *exec.Cmd) (io.ReadCloser, io.ReadCloser, error) { - stdout, err := cmd.StdoutPipe() - if err != nil { - return nil, nil, fmt.Errorf("error creating stdout pipe: %w", err) - } - - stderr, err := cmd.StderrPipe() - if err != nil { - stdout.Close() // Clean up stdout if stderr fails - return nil, nil, fmt.Errorf("error creating stderr pipe: %w", err) - } - - return stdout, stderr, nil -} +//go:build linux + +package backup + +import ( + "fmt" + "io" + "os" + "os/exec" + "strings" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" +) + +func prepareBackupCommand(job *types.Job, storeInstance *store.Store, srcPath string, isAgent bool) (*exec.Cmd, error) { + if srcPath == "" { + return nil, fmt.Errorf("RunBackup: source path is required") + } + + backupId, err := getBackupId(isAgent, job.Target) + if err != nil { + return nil, fmt.Errorf("RunBackup: failed to get backup ID: %w", err) + } + + jobStore := fmt.Sprintf("%s@localhost:%s", proxmox.Session.APIToken.TokenId, job.Store) + if jobStore == "@localhost:" { + return nil, fmt.Errorf("RunBackup: invalid job store configuration") + } + + cmdArgs := buildCommandArgs(storeInstance, job, srcPath, jobStore, backupId) + if len(cmdArgs) == 0 { + return nil, fmt.Errorf("RunBackup: failed to build command arguments") + } + 
+ cmd := exec.Command("/usr/bin/proxmox-backup-client", cmdArgs...) + cmd.Env = buildCommandEnv(storeInstance) + + return cmd, nil +} + +func getBackupId(isAgent bool, targetName string) (string, error) { + if !isAgent { + hostname, err := os.Hostname() + if err != nil { + hostnameBytes, err := os.ReadFile("/etc/hostname") + if err != nil { + return "localhost", nil + } + return strings.TrimSpace(string(hostnameBytes)), nil + } + return hostname, nil + } + if targetName == "" { + return "", fmt.Errorf("target name is required for agent backup") + } + return strings.TrimSpace(strings.Split(targetName, " - ")[0]), nil +} + +func buildCommandArgs(storeInstance *store.Store, job *types.Job, srcPath string, jobStore string, backupId string) []string { + if srcPath == "" || jobStore == "" || backupId == "" { + return nil + } + + cmdArgs := []string{ + "backup", + fmt.Sprintf("%s.pxar:%s", strings.ReplaceAll(job.Target, " ", "-"), srcPath), + "--repository", jobStore, + "--change-detection-mode=metadata", + "--backup-id", backupId, + "--crypt-mode=none", + "--skip-e2big-xattr", "true", + "--skip-lost-and-found", "true", + } + + // Add exclusions + for _, exclusion := range job.Exclusions { + path := exclusion.Path + if !strings.HasPrefix(exclusion.Path, "/") && !strings.HasPrefix(exclusion.Path, "!") && !strings.HasPrefix(exclusion.Path, "**/") { + path = "**/" + path + } + + cmdArgs = append(cmdArgs, "--exclude", path) + } + + // Add namespace if specified + if job.Namespace != "" { + _ = CreateNamespace(job.Namespace, job, storeInstance) + cmdArgs = append(cmdArgs, "--ns", job.Namespace) + } + + return cmdArgs +} + +func buildCommandEnv(storeInstance *store.Store) []string { + if storeInstance == nil || proxmox.Session.APIToken == nil { + return os.Environ() + } + + env := append(os.Environ(), + fmt.Sprintf("PBS_PASSWORD=%s", proxmox.Session.APIToken.Value)) + // env = append(env, "PBS_LOG=debug") + + // Add fingerprint if available + if pbsStatus, err := proxmox.Session.GetPBSStatus(); err == nil { + if fingerprint, ok := pbsStatus.Info["fingerprint"]; ok { + env = append(env, fmt.Sprintf("PBS_FINGERPRINT=%s", fingerprint)) + } + } + + return env +} + +func setupCommandPipes(cmd *exec.Cmd) (io.ReadCloser, io.ReadCloser, error) { + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, fmt.Errorf("error creating stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + stdout.Close() // Clean up stdout if stderr fails + return nil, nil, fmt.Errorf("error creating stderr pipe: %w", err) + } + + return stdout, stderr, nil +} diff --git a/internal/backend/backup/logging.go b/internal/backend/backup/logging.go index 34774c7..74d455e 100644 --- a/internal/backend/backup/logging.go +++ b/internal/backend/backup/logging.go @@ -1,114 +1,114 @@ -package backup - -import ( - "bufio" - "fmt" - "io" - "log" - "os" - "strings" - "sync" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func collectLogs(stdout, stderr io.ReadCloser) ([]string, error) { - defer stdout.Close() - defer stderr.Close() - - linesCh := make(chan string) - errCh := make(chan error, 2) - var wg sync.WaitGroup - - scanner := func(r io.Reader) { - defer wg.Done() - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - log.Println(line) // Log to console - linesCh <- line - } - if err := scanner.Err(); err != nil { - errCh <- fmt.Errorf("error reading logs: %w", err) - } - } - - wg.Add(2) - go scanner(stdout) - go scanner(stderr) - - go func() { - 
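
One rule in buildCommandArgs is easy to miss: relative exclusion patterns are made recursive by prefixing "**/", while absolute paths, "!" negations, and already-recursive patterns pass through unchanged. The same rule in isolation, as a runnable sketch:

package main

import (
	"fmt"
	"strings"
)

// normalizeExclusion mirrors the rule in buildCommandArgs: a bare pattern
// like "node_modules" becomes "**/node_modules" so it matches at any depth,
// while "/var/tmp", "!keep.me", and "**/cache" are left as-is.
func normalizeExclusion(path string) string {
	if !strings.HasPrefix(path, "/") &&
		!strings.HasPrefix(path, "!") &&
		!strings.HasPrefix(path, "**/") {
		return "**/" + path
	}
	return path
}

func main() {
	for _, p := range []string{"node_modules", "/var/tmp", "!keep.me", "**/cache"} {
		fmt.Printf("%-14s -> %s\n", p, normalizeExclusion(p))
	}
}
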
wg.Wait() - close(linesCh) - close(errCh) - }() - - var logLines []string - for line := range linesCh { - logLines = append(logLines, line) - } - - var errs []error - for err := range errCh { - errs = append(errs, err) - } - - if len(errs) > 0 { - return nil, fmt.Errorf("errors reading logs: %v", errs) - } - - return logLines, nil -} - -func writeLogsToFile(upid string, logLines []string) error { - if logLines == nil { - return fmt.Errorf("logLines is nil") - } - - if err := utils.WaitForLogFile(upid, 1*time.Minute); err != nil { - return fmt.Errorf("log file cannot be opened: %w", err) - } - - time.Sleep(5 * time.Second) - - logFilePath := utils.GetTaskLogPath(upid) - logFile, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - return fmt.Errorf("failed to open log file: %w", err) - } - defer logFile.Close() - - writer := bufio.NewWriter(logFile) - defer writer.Flush() - - if _, err := writer.WriteString("--- proxmox-backup-client log starts here ---\n"); err != nil { - return fmt.Errorf("failed to write log header: %w", err) - } - - hasError := false - var errorString string - timestamp := time.Now().Format(time.RFC3339) - - for _, logLine := range logLines { - if strings.Contains(logLine, "Error: upload failed:") { - errorString = strings.Replace(logLine, "Error:", "TASK ERROR:", 1) - hasError = true - continue - } - if _, err := writer.WriteString(fmt.Sprintf("%s: %s\n", timestamp, logLine)); err != nil { - return fmt.Errorf("failed to write log line: %w", err) - } - } - - finalStatus := fmt.Sprintf("%s: TASK OK\n", timestamp) - if hasError { - finalStatus = fmt.Sprintf("%s: %s\n", timestamp, errorString) - } - - if _, err := writer.WriteString(finalStatus); err != nil { - return fmt.Errorf("failed to write final status: %w", err) - } - - return nil -} +package backup + +import ( + "bufio" + "fmt" + "io" + "log" + "os" + "strings" + "sync" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func collectLogs(stdout, stderr io.ReadCloser) ([]string, error) { + defer stdout.Close() + defer stderr.Close() + + linesCh := make(chan string) + errCh := make(chan error, 2) + var wg sync.WaitGroup + + scanner := func(r io.Reader) { + defer wg.Done() + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + log.Println(line) // Log to console + linesCh <- line + } + if err := scanner.Err(); err != nil { + errCh <- fmt.Errorf("error reading logs: %w", err) + } + } + + wg.Add(2) + go scanner(stdout) + go scanner(stderr) + + go func() { + wg.Wait() + close(linesCh) + close(errCh) + }() + + var logLines []string + for line := range linesCh { + logLines = append(logLines, line) + } + + var errs []error + for err := range errCh { + errs = append(errs, err) + } + + if len(errs) > 0 { + return nil, fmt.Errorf("errors reading logs: %v", errs) + } + + return logLines, nil +} + +func writeLogsToFile(upid string, logLines []string) error { + if logLines == nil { + return fmt.Errorf("logLines is nil") + } + + if err := utils.WaitForLogFile(upid, 1*time.Minute); err != nil { + return fmt.Errorf("log file cannot be opened: %w", err) + } + + time.Sleep(5 * time.Second) + + logFilePath := utils.GetTaskLogPath(upid) + logFile, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + defer logFile.Close() + + writer := bufio.NewWriter(logFile) + defer writer.Flush() + + if _, err := writer.WriteString("--- proxmox-backup-client log 
starts here ---\n"); err != nil { + return fmt.Errorf("failed to write log header: %w", err) + } + + hasError := false + var errorString string + timestamp := time.Now().Format(time.RFC3339) + + for _, logLine := range logLines { + if strings.Contains(logLine, "Error: upload failed:") { + errorString = strings.Replace(logLine, "Error:", "TASK ERROR:", 1) + hasError = true + continue + } + if _, err := writer.WriteString(fmt.Sprintf("%s: %s\n", timestamp, logLine)); err != nil { + return fmt.Errorf("failed to write log line: %w", err) + } + } + + finalStatus := fmt.Sprintf("%s: TASK OK\n", timestamp) + if hasError { + finalStatus = fmt.Sprintf("%s: %s\n", timestamp, errorString) + } + + if _, err := writer.WriteString(finalStatus); err != nil { + return fmt.Errorf("failed to write final status: %w", err) + } + + return nil +} diff --git a/internal/backend/backup/status.go b/internal/backend/backup/status.go index 391b820..eb7366f 100644 --- a/internal/backend/backup/status.go +++ b/internal/backend/backup/status.go @@ -1,37 +1,37 @@ -//go:build linux - -package backup - -import ( - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" -) - -func updateJobStatus(job *types.Job, task *proxmox.Task, storeInstance *store.Store) error { - // Update task status - taskFound, err := proxmox.Session.GetTaskByUPID(task.UPID) - if err != nil { - syslog.L.Errorf("Unable to get task by UPID: %v", err) - return err - } - - // Update job status - latestJob, err := storeInstance.Database.GetJob(job.ID) - if err != nil { - syslog.L.Errorf("Unable to get job: %v", err) - return err - } - - latestJob.LastRunUpid = taskFound.UPID - latestJob.LastRunState = &taskFound.Status - latestJob.LastRunEndtime = &taskFound.EndTime - - if err := storeInstance.Database.UpdateJob(*latestJob); err != nil { - syslog.L.Errorf("Unable to update job: %v", err) - return err - } - - return nil -} +//go:build linux + +package backup + +import ( + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" +) + +func updateJobStatus(job *types.Job, task *proxmox.Task, storeInstance *store.Store) error { + // Update task status + taskFound, err := proxmox.Session.GetTaskByUPID(task.UPID) + if err != nil { + syslog.L.Errorf("Unable to get task by UPID: %v", err) + return err + } + + // Update job status + latestJob, err := storeInstance.Database.GetJob(job.ID) + if err != nil { + syslog.L.Errorf("Unable to get job: %v", err) + return err + } + + latestJob.LastRunUpid = taskFound.UPID + latestJob.LastRunState = &taskFound.Status + latestJob.LastRunEndtime = &taskFound.EndTime + + if err := storeInstance.Database.UpdateJob(*latestJob); err != nil { + syslog.L.Errorf("Unable to update job: %v", err) + return err + } + + return nil +} diff --git a/internal/backend/mount/mount.go b/internal/backend/mount/mount.go index 8cc41dc..885fef5 100644 --- a/internal/backend/mount/mount.go +++ b/internal/backend/mount/mount.go @@ -1,139 +1,139 @@ -//go:build linux - -package mount - -import ( - "encoding/base32" - "fmt" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" - 
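
collectLogs above follows a standard fan-in shape: one scanner goroutine per pipe, a third goroutine that closes the channel only after both scanners return, and a single consumer loop. The same pattern reduced to a runnable miniature against an arbitrary command (error handling trimmed for brevity):

package main

import (
	"bufio"
	"fmt"
	"io"
	"os/exec"
	"sync"
)

func main() {
	cmd := exec.Command("sh", "-c", "echo out; echo err 1>&2")
	stdout, _ := cmd.StdoutPipe()
	stderr, _ := cmd.StderrPipe()
	if err := cmd.Start(); err != nil {
		panic(err)
	}

	lines := make(chan string)
	var wg sync.WaitGroup
	scan := func(r io.Reader) {
		defer wg.Done()
		s := bufio.NewScanner(r)
		for s.Scan() {
			lines <- s.Text()
		}
	}

	wg.Add(2)
	go scan(stdout)
	go scan(stderr)
	// Close only once both scanners have drained their pipes.
	go func() { wg.Wait(); close(lines) }()

	for line := range lines {
		fmt.Println(line)
	}
	_ = cmd.Wait() // Wait must come after the pipes are fully consumed.
}
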
"github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -type AgentMount struct { - Hostname string - Drive string - Path string - Cmd *exec.Cmd -} - -func Mount(storeInstance *store.Store, target *types.Target) (*AgentMount, error) { - // Parse target information - splittedTargetName := strings.Split(target.Name, " - ") - targetHostname := splittedTargetName[0] - agentPath := strings.TrimPrefix(target.Path, "agent://") - agentPathParts := strings.Split(agentPath, "/") - agentHost := agentPathParts[0] - agentDrive := agentPathParts[1] - - // Encode hostname and drive for API call - targetHostnameEnc := base32.StdEncoding.EncodeToString([]byte(targetHostname)) - agentDriveEnc := base32.StdEncoding.EncodeToString([]byte(agentDrive)) - - // Request mount from agent - err := proxmox.Session.ProxmoxHTTPRequest( - http.MethodPost, - fmt.Sprintf("https://localhost:8008/plus/mount/%s/%s", targetHostnameEnc, agentDriveEnc), - nil, - nil, - ) - if err != nil { - return nil, fmt.Errorf("Mount: Failed to send mount request to target '%s' -> %w", target.Name, err) - } - - agentMount := &AgentMount{ - Hostname: targetHostname, - Drive: agentDrive, - } - - // Get port for NFS connection - agentDriveRune := []rune(agentDrive)[0] - agentPort, err := utils.DriveLetterPort(agentDriveRune) - if err != nil { - agentMount.Unmount() - return nil, fmt.Errorf("Mount: error mapping \"%c\" to network port -> %w", agentDriveRune, err) - } - - // Setup mount path - agentMount.Path = filepath.Join(constants.AgentMountBasePath, strings.ReplaceAll(target.Name, " ", "-")) - agentMount.Unmount() // Ensure clean mount point - - // Create mount directory if it doesn't exist - err = os.MkdirAll(agentMount.Path, 0700) - if err != nil { - return nil, fmt.Errorf("Mount: error creating directory \"%s\" -> %w", agentMount.Path, err) - } - - // Mount using NFS - mountArgs := []string{ - "-t", "nfs", - "-o", fmt.Sprintf("port=%s,mountport=%s,vers=3,soft,timeo=100,ro,tcp,noacl,nocto,actimeo=3600,rsize=1048576,lookupcache=positive,noatime", agentPort, agentPort), - fmt.Sprintf("%s:/", agentHost), - agentMount.Path, - } - - // Mount the NFS share - mnt := exec.Command("mount", mountArgs...) 
- mnt.Env = os.Environ() - mnt.Stdout = os.Stdout - mnt.Stderr = os.Stderr - agentMount.Cmd = mnt - - // Try mounting with retries - const maxRetries = 3 - const retryDelay = 2 * time.Second - - var lastErr error - for i := 0; i < maxRetries; i++ { - err = mnt.Run() - if err == nil { - return agentMount, nil - } - lastErr = err - if i < maxRetries-1 { - time.Sleep(retryDelay) - } - } - - // If all retries failed, clean up and return error - agentMount.Unmount() - agentMount.CloseMount() - return nil, fmt.Errorf("Mount: error mounting NFS share after %d attempts -> %w", maxRetries, lastErr) -} - -func (a *AgentMount) Unmount() { - if a.Path == "" { - return - } - - // First try a clean unmount - umount := exec.Command("umount", "-lf", a.Path) - umount.Env = os.Environ() - _ = umount.Run() - - // Kill any lingering mount process - if a.Cmd != nil && a.Cmd.Process != nil { - _ = a.Cmd.Process.Kill() - } -} - -func (a *AgentMount) CloseMount() { - targetHostnameEnc := base32.StdEncoding.EncodeToString([]byte(a.Hostname)) - agentDriveEnc := base32.StdEncoding.EncodeToString([]byte(a.Drive)) - - _ = proxmox.Session.ProxmoxHTTPRequest( - http.MethodDelete, - fmt.Sprintf("https://localhost:8008/plus/mount/%s/%s", targetHostnameEnc, agentDriveEnc), - nil, - nil, - ) -} +//go:build linux + +package mount + +import ( + "encoding/base32" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" + "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +type AgentMount struct { + Hostname string + Drive string + Path string + Cmd *exec.Cmd +} + +func Mount(storeInstance *store.Store, target *types.Target) (*AgentMount, error) { + // Parse target information + splittedTargetName := strings.Split(target.Name, " - ") + targetHostname := splittedTargetName[0] + agentPath := strings.TrimPrefix(target.Path, "agent://") + agentPathParts := strings.Split(agentPath, "/") + agentHost := agentPathParts[0] + agentDrive := agentPathParts[1] + + // Encode hostname and drive for API call + targetHostnameEnc := base32.StdEncoding.EncodeToString([]byte(targetHostname)) + agentDriveEnc := base32.StdEncoding.EncodeToString([]byte(agentDrive)) + + // Request mount from agent + err := proxmox.Session.ProxmoxHTTPRequest( + http.MethodPost, + fmt.Sprintf("https://localhost:8008/plus/mount/%s/%s", targetHostnameEnc, agentDriveEnc), + nil, + nil, + ) + if err != nil { + return nil, fmt.Errorf("Mount: Failed to send mount request to target '%s' -> %w", target.Name, err) + } + + agentMount := &AgentMount{ + Hostname: targetHostname, + Drive: agentDrive, + } + + // Get port for NFS connection + agentDriveRune := []rune(agentDrive)[0] + agentPort, err := utils.DriveLetterPort(agentDriveRune) + if err != nil { + agentMount.Unmount() + return nil, fmt.Errorf("Mount: error mapping \"%c\" to network port -> %w", agentDriveRune, err) + } + + // Setup mount path + agentMount.Path = filepath.Join(constants.AgentMountBasePath, strings.ReplaceAll(target.Name, " ", "-")) + agentMount.Unmount() // Ensure clean mount point + + // Create mount directory if it doesn't exist + err = os.MkdirAll(agentMount.Path, 0700) + if err != nil { + return nil, fmt.Errorf("Mount: error creating directory \"%s\" -> %w", agentMount.Path, err) + } + + // Mount using NFS + mountArgs := []string{ + "-t", 
"nfs", + "-o", fmt.Sprintf("port=%s,mountport=%s,vers=3,soft,timeo=100,ro,tcp,noacl,nocto,actimeo=3600,rsize=1048576,lookupcache=positive,noatime", agentPort, agentPort), + fmt.Sprintf("%s:/", agentHost), + agentMount.Path, + } + + // Mount the NFS share + mnt := exec.Command("mount", mountArgs...) + mnt.Env = os.Environ() + mnt.Stdout = os.Stdout + mnt.Stderr = os.Stderr + agentMount.Cmd = mnt + + // Try mounting with retries + const maxRetries = 3 + const retryDelay = 2 * time.Second + + var lastErr error + for i := 0; i < maxRetries; i++ { + err = mnt.Run() + if err == nil { + return agentMount, nil + } + lastErr = err + if i < maxRetries-1 { + time.Sleep(retryDelay) + } + } + + // If all retries failed, clean up and return error + agentMount.Unmount() + agentMount.CloseMount() + return nil, fmt.Errorf("Mount: error mounting NFS share after %d attempts -> %w", maxRetries, lastErr) +} + +func (a *AgentMount) Unmount() { + if a.Path == "" { + return + } + + // First try a clean unmount + umount := exec.Command("umount", "-lf", a.Path) + umount.Env = os.Environ() + _ = umount.Run() + + // Kill any lingering mount process + if a.Cmd != nil && a.Cmd.Process != nil { + _ = a.Cmd.Process.Kill() + } +} + +func (a *AgentMount) CloseMount() { + targetHostnameEnc := base32.StdEncoding.EncodeToString([]byte(a.Hostname)) + agentDriveEnc := base32.StdEncoding.EncodeToString([]byte(a.Drive)) + + _ = proxmox.Session.ProxmoxHTTPRequest( + http.MethodDelete, + fmt.Sprintf("https://localhost:8008/plus/mount/%s/%s", targetHostnameEnc, agentDriveEnc), + nil, + nil, + ) +} diff --git a/internal/config/config_tags.go b/internal/config/config_tags.go index 66e8cf6..09a0419 100644 --- a/internal/config/config_tags.go +++ b/internal/config/config_tags.go @@ -1,125 +1,125 @@ -//go:build linux - -package config - -import ( - "fmt" - "reflect" - "strconv" - "strings" -) - -// ConfigTag represents the parsed configuration tags from struct fields -type ConfigTag struct { - Type PropertyType - Key string - Required bool - MinLength *int - MaxLength *int - IsID bool // New field to mark which field is the section ID -} - -// parseConfigTags parses struct field tags to extract configuration metadata -func parseConfigTags(tag string) (ConfigTag, error) { - result := ConfigTag{} - - if tag == "" { - return result, nil - } - - tags := strings.Split(tag, ",") - for _, t := range tags { - parts := strings.SplitN(t, "=", 2) - key := parts[0] - - switch key { - case "key": - if len(parts) != 2 { - return result, fmt.Errorf("type key requires a value") - } - result.Key = strings.TrimSpace(parts[1]) - case "type": - if len(parts) != 2 { - return result, fmt.Errorf("type tag requires a value") - } - result.Type = PropertyType(parts[1]) - case "required": - result.Required = true - case "min": - if len(parts) != 2 { - return result, fmt.Errorf("min tag requires a value") - } - val, err := strconv.Atoi(parts[1]) - if err != nil { - return result, fmt.Errorf("invalid min value: %w", err) - } - result.MinLength = &val - case "max": - if len(parts) != 2 { - return result, fmt.Errorf("max tag requires a value") - } - val, err := strconv.Atoi(parts[1]) - if err != nil { - return result, fmt.Errorf("invalid max value: %w", err) - } - result.MaxLength = &val - case "id": - result.IsID = true - } - } - - return result, nil -} - -// validateFieldWithTags validates a field value against its configuration tags -func validateFieldWithTags(value interface{}, tags ConfigTag) error { - switch tags.Type { - case TypeString: - str, ok := 
value.(string) - if !ok && tags.Required { - return fmt.Errorf("expected string value") - } - - if tags.Required && str == "" { - return fmt.Errorf("required field is empty") - } - - if tags.MinLength != nil && len(str) < *tags.MinLength { - return fmt.Errorf("value length %d is less than minimum %d", len(str), *tags.MinLength) - } - - if tags.MaxLength != nil && len(str) > *tags.MaxLength { - return fmt.Errorf("value length %d is greater than maximum %d", len(str), *tags.MaxLength) - } - - case TypeInt: - _, ok := value.(int) - if !ok && tags.Required { - return fmt.Errorf("expected integer value") - } - - case TypeBool: - _, ok := value.(bool) - if !ok && tags.Required { - return fmt.Errorf("expected boolean value") - } - - case TypeArray: - val := reflect.ValueOf(value) - if val.Kind() != reflect.Slice { - return fmt.Errorf("expected array value") - } - length := val.Len() - if tags.Required && length == 0 { - return fmt.Errorf("required array is empty") - } - if tags.MinLength != nil && length < *tags.MinLength { - return fmt.Errorf("array length %d is less than minimum %d", length, *tags.MinLength) - } - if tags.MaxLength != nil && length > *tags.MaxLength { - return fmt.Errorf("array length %d is greater than maximum %d", length, *tags.MaxLength) - } - } - - return nil -} +//go:build linux + +package config + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// ConfigTag represents the parsed configuration tags from struct fields +type ConfigTag struct { + Type PropertyType + Key string + Required bool + MinLength *int + MaxLength *int + IsID bool // New field to mark which field is the section ID +} + +// parseConfigTags parses struct field tags to extract configuration metadata +func parseConfigTags(tag string) (ConfigTag, error) { + result := ConfigTag{} + + if tag == "" { + return result, nil + } + + tags := strings.Split(tag, ",") + for _, t := range tags { + parts := strings.SplitN(t, "=", 2) + key := parts[0] + + switch key { + case "key": + if len(parts) != 2 { + return result, fmt.Errorf("type key requires a value") + } + result.Key = strings.TrimSpace(parts[1]) + case "type": + if len(parts) != 2 { + return result, fmt.Errorf("type tag requires a value") + } + result.Type = PropertyType(parts[1]) + case "required": + result.Required = true + case "min": + if len(parts) != 2 { + return result, fmt.Errorf("min tag requires a value") + } + val, err := strconv.Atoi(parts[1]) + if err != nil { + return result, fmt.Errorf("invalid min value: %w", err) + } + result.MinLength = &val + case "max": + if len(parts) != 2 { + return result, fmt.Errorf("max tag requires a value") + } + val, err := strconv.Atoi(parts[1]) + if err != nil { + return result, fmt.Errorf("invalid max value: %w", err) + } + result.MaxLength = &val + case "id": + result.IsID = true + } + } + + return result, nil +} + +// validateFieldWithTags validates a field value against its configuration tags +func validateFieldWithTags(value interface{}, tags ConfigTag) error { + switch tags.Type { + case TypeString: + str, ok := value.(string) + if !ok && tags.Required { + return fmt.Errorf("expected string value") + } + + if tags.Required && str == "" { + return fmt.Errorf("required field is empty") + } + + if tags.MinLength != nil && len(str) < *tags.MinLength { + return fmt.Errorf("value length %d is less than minimum %d", len(str), *tags.MinLength) + } + + if tags.MaxLength != nil && len(str) > *tags.MaxLength { + return fmt.Errorf("value length %d is greater than maximum %d", len(str), *tags.MaxLength) + } + 
+ case TypeInt: + _, ok := value.(int) + if !ok && tags.Required { + return fmt.Errorf("expected integer value") + } + + case TypeBool: + _, ok := value.(bool) + if !ok && tags.Required { + return fmt.Errorf("expected boolean value") + } + + case TypeArray: + val := reflect.ValueOf(value) + if val.Kind() != reflect.Slice { + return fmt.Errorf("expected array value") + } + length := val.Len() + if tags.Required && length == 0 { + return fmt.Errorf("required array is empty") + } + if tags.MinLength != nil && length < *tags.MinLength { + return fmt.Errorf("array length %d is less than minimum %d", length, *tags.MinLength) + } + if tags.MaxLength != nil && length > *tags.MaxLength { + return fmt.Errorf("array length %d is greater than maximum %d", length, *tags.MaxLength) + } + } + + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2c98464..69aab38 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,578 +1,578 @@ -//go:build linux - -package config - -import ( - "fmt" - "os" - "path/filepath" - "testing" - - "github.com/sonroyaalmerol/pbs-plus/internal/utils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Test struct definitions -type BasicTestConfig struct { - Name string `config:"type=string,required"` - Value string `config:"type=string,required"` - Optional string `config:"type=string"` -} - -type ArrayTestConfig struct { - Tags []string `config:"type=array,required"` -} - -type ValidationTestConfig struct { - Email string `config:"type=string,required"` -} - -type CompatTestConfig struct { - Path string `config:"type=string,required"` - Comment string `config:"type=string"` - Count int `config:"type=int"` - Enabled bool `config:"type=bool"` - Tags []string `config:"type=array"` -} - -type CustomKeyTestConfig struct { - ServerName string `config:"type=string,required,key=server_name"` - MaxConn int `config:"type=int,key=max_connections"` - EnableLogs bool `config:"type=bool,key=enable_logging"` - Categories []string `config:"type=array,key=category_list"` - Description string `config:"type=string"` // No custom key - should use field name lowercase -} - -func TestSectionConfig_BasicOperations(t *testing.T) { - // Setup - tempDir := t.TempDir() - - // Create plugin with basic config - testPlugin := &SectionPlugin[BasicTestConfig]{ - TypeName: "test", - FolderPath: tempDir, - Validate: nil, - } - config := NewSectionConfig(testPlugin) - - t.Run("Create and Read", func(t *testing.T) { - testFile := filepath.Join(tempDir, utils.EncodePath("test-basic-cr")+".cfg") - testData := &ConfigData[BasicTestConfig]{ - Sections: map[string]*Section[BasicTestConfig]{ - "test-basic-cr": { - Type: "test", - ID: "test-basic-cr", - Properties: BasicTestConfig{ - Name: "Test 1", - Value: "Value 1", - }, - }, - }, - Order: []string{"test-basic-cr"}, - } - - // Write config - err := config.Write(testData) - require.NoError(t, err) - - // Read config - readData, err := config.Parse(testFile) - require.NoError(t, err) - - // Verify data - assert.Equal(t, testData.Order, readData.Order) - assert.Equal(t, testData.Sections["test-basic-cr"].Properties.Name, - readData.Sections["test-basic-cr"].Properties.Name) - }) - - t.Run("Missing Required Property", func(t *testing.T) { - testData := &ConfigData[BasicTestConfig]{ - Sections: map[string]*Section[BasicTestConfig]{ - "test-missing": { - Type: "test", - ID: "test-missing", - Properties: BasicTestConfig{ - Name: "Test 1", - // Missing required 
Value - }, - }, - }, - Order: []string{"test-missing"}, - } - - err := config.Write(testData) - assert.Error(t, err) - }) -} - -func TestSectionConfig_ArrayProperties(t *testing.T) { - // Setup - tempDir := t.TempDir() - - // Create plugin with array config - arrayPlugin := &SectionPlugin[ArrayTestConfig]{ - TypeName: "array-test", - FolderPath: tempDir, - Validate: nil, - } - config := NewSectionConfig(arrayPlugin) - - t.Run("Array Property Handling", func(t *testing.T) { - testFile := filepath.Join(tempDir, utils.EncodePath("test-array")+".cfg") - testData := &ConfigData[ArrayTestConfig]{ - Sections: map[string]*Section[ArrayTestConfig]{ - "test-array": { - Type: "array-test", - ID: "test-array", - Properties: ArrayTestConfig{ - Tags: []string{"tag1", "tag2", "tag3"}, - }, - }, - }, - Order: []string{"test-array"}, - } - - err := config.Write(testData) - require.NoError(t, err) - - readData, err := config.Parse(testFile) - require.NoError(t, err) - - assert.Equal(t, testData.Sections["test-array"].Properties.Tags, - readData.Sections["test-array"].Properties.Tags) - }) -} - -func TestSectionConfig_ValidationRules(t *testing.T) { - // Setup - tempDir := t.TempDir() - - // Create plugin with validation config - validationPlugin := &SectionPlugin[ValidationTestConfig]{ - TypeName: "validation-test", - FolderPath: tempDir, - Validate: func(c ValidationTestConfig) error { - if len(c.Email) > 254 { - return fmt.Errorf("email too long") - } - return nil - }, - } - config := NewSectionConfig(validationPlugin) - - t.Run("Valid Pattern", func(t *testing.T) { - testData := &ConfigData[ValidationTestConfig]{ - Sections: map[string]*Section[ValidationTestConfig]{ - "test-validate": { - Type: "validation-test", - ID: "test-validate", - Properties: ValidationTestConfig{ - Email: "test@example.com", - }, - }, - }, - Order: []string{"test-validate"}, - } - - err := config.Write(testData) - require.NoError(t, err) - }) - - t.Run("Email Too Long", func(t *testing.T) { - longEmail := "very-long-email" - for i := 0; i < 250; i++ { - longEmail += "x" - } - longEmail += "@example.com" - - testData := &ConfigData[ValidationTestConfig]{ - Sections: map[string]*Section[ValidationTestConfig]{ - "test-long-email": { - Type: "validation-test", - ID: "test-long-email", - Properties: ValidationTestConfig{ - Email: longEmail, - }, - }, - }, - Order: []string{"test-long-email"}, - } - - err := config.Write(testData) - assert.Error(t, err) - }) -} - -// Test edge cases from old format -func TestEdgeCaseCompatibility(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "edge.cfg") - - // Create a config file with edge cases - oldFormatConfig := `test: edge-case - path /test/path - comment - count - enabled false - tags - -` - err := os.WriteFile(configPath, []byte(oldFormatConfig), 0644) - require.NoError(t, err) - - plugin := &SectionPlugin[CompatTestConfig]{ - TypeName: "test", - FolderPath: tempDir, - } - config := NewSectionConfig(plugin) - - // Read and verify edge cases - configData, err := config.Parse(configPath) - require.NoError(t, err) - - section := configData.Sections["edge-case"] - props := section.Properties - - // Verify empty fields are handled correctly - require.Equal(t, "/test/path", props.Path) - require.Equal(t, "", props.Comment) - require.Equal(t, 0, props.Count) - require.False(t, props.Enabled) - require.Empty(t, props.Tags) - - // Write and verify it maintains format - err = config.Write(configData) - require.NoError(t, err) -} - -func TestFormatCompatibility(t 
*testing.T) { - tempDir := t.TempDir() - - tests := []struct { - name string - oldConfig string - expectedPath string - expectedCount int - expectedTags []string - expectedOutput string - }{ - { - name: "Tab Indentation", - oldConfig: `test: tab-indent - path /test/path - count 42 - tags tag1,tag2 - -`, - expectedPath: "/test/path", - expectedCount: 42, - expectedTags: []string{"tag1", "tag2"}, - expectedOutput: `test: tab-indent - count 42 - path /test/path - tags tag1,tag2 - -`, - }, - { - name: "Space Indentation", - oldConfig: `test: space-indent - path /test/path - count 42 - tags tag1,tag2 - -`, - expectedPath: "/test/path", - expectedCount: 42, - expectedTags: []string{"tag1", "tag2"}, - expectedOutput: `test: space-indent - count 42 - path /test/path - tags tag1,tag2 - -`, - }, - { - name: "Mixed Whitespace", - oldConfig: `test: mixed-ws - path /test/path - count 42 - tags tag1, tag2 - -`, - expectedPath: "/test/path", - expectedCount: 42, - expectedTags: []string{"tag1", "tag2"}, - expectedOutput: `test: mixed-ws - count 42 - path /test/path - tags tag1,tag2 - -`, - }, - { - name: "Multiple Sections", - oldConfig: `test: section1 - path /path1 - count 1 - -test: section2 - path /path2 - count 2 - -`, - expectedPath: "/path1", - expectedCount: 1, - expectedTags: nil, - expectedOutput: `test: section1 - count 1 - path /path1 - -test: section2 - count 2 - path /path2 - -`, - }, - { - name: "Empty Values", - oldConfig: `test: empty-values - path /test/path - count 0 - tags - enabled false - -`, - expectedPath: "/test/path", - expectedCount: 0, - expectedTags: nil, - expectedOutput: `test: empty-values - path /test/path - tags - -`, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - plugin := &SectionPlugin[CompatTestConfig]{ - TypeName: "test", - FolderPath: tempDir, - } - config := NewSectionConfig(plugin) - - configPath := filepath.Join(tempDir, fmt.Sprintf("%s.cfg", tc.name)) - err := os.WriteFile(configPath, []byte(tc.oldConfig), 0644) - require.NoError(t, err) - - configData, err := config.Parse(configPath) - require.NoError(t, err) - - var firstSectionID string - for id := range configData.Sections { - firstSectionID = id - break - } - section := configData.Sections[firstSectionID] - require.NotNil(t, section) - - assert.Equal(t, tc.expectedPath, section.Properties.Path) - assert.Equal(t, tc.expectedCount, section.Properties.Count) - - if tc.expectedTags == nil { - assert.Empty(t, section.Properties.Tags) - } else { - assert.Equal(t, tc.expectedTags, section.Properties.Tags) - } - - err = config.Write(configData) - require.NoError(t, err) - - written, err := os.ReadFile(configPath) - require.NoError(t, err) - assert.Equal(t, tc.expectedOutput, string(written)) - - // Verify the written file can be parsed again - _, err = config.Parse(configPath) - require.NoError(t, err, "Written file should be parseable") - }) - } -} - -// TestCrossImplementationRoundTrip tests that configs can be written by old implementation -// and read by new, and vice versa -func TestCrossImplementationRoundTrip(t *testing.T) { - tempDir := t.TempDir() - err := os.MkdirAll(tempDir, 0750) - require.NoError(t, err) - - configPath := filepath.Join(tempDir, "roundtrip.cfg") - - testConfig := &ConfigData[CompatTestConfig]{ - FilePath: configPath, // Set the filepath explicitly - Sections: map[string]*Section[CompatTestConfig]{ - "test-section": { - Type: "test", - ID: "test-section", - Properties: CompatTestConfig{ - Path: "/complex/path with spaces", - Comment: "Multi word\tcomment 
with\ttabs", - Count: 42, - Enabled: true, - Tags: []string{"tag1", "tag with space"}, - }, - }, - }, - Order: []string{"test-section"}, - } - - plugin := &SectionPlugin[CompatTestConfig]{ - TypeName: "test", - FolderPath: tempDir, - } - config := NewSectionConfig(plugin) - - // Write with new implementation - err = config.Write(testConfig) - require.NoError(t, err) - - // Verify file exists - _, err = os.Stat(configPath) - require.NoError(t, err, "Config file should exist") - - // Read the config back - readConfig, err := config.Parse(configPath) - require.NoError(t, err) - - // Verify all fields match exactly - original := testConfig.Sections["test-section"].Properties - read := readConfig.Sections["test-section"].Properties - - assert.Equal(t, original.Path, read.Path) - assert.Equal(t, original.Comment, read.Comment) - assert.Equal(t, original.Count, read.Count) - assert.Equal(t, original.Enabled, read.Enabled) - assert.Equal(t, original.Tags, read.Tags) - - // Verify section order is preserved - assert.Equal(t, testConfig.Order, readConfig.Order) -} - -func TestCustomKeyConfig(t *testing.T) { - tempDir := t.TempDir() - - plugin := &SectionPlugin[CustomKeyTestConfig]{ - TypeName: "server", - FolderPath: tempDir, - } - config := NewSectionConfig(plugin) - - t.Run("Write and Read Custom Keys", func(t *testing.T) { - testFile := filepath.Join(tempDir, utils.EncodePath("test-custom-keys")+".cfg") - - // Create test data with all fields populated - testData := &ConfigData[CustomKeyTestConfig]{ - FilePath: testFile, - Sections: map[string]*Section[CustomKeyTestConfig]{ - "test-custom-keys": { - Type: "server", - ID: "test-custom-keys", - Properties: CustomKeyTestConfig{ - ServerName: "TestServer", - MaxConn: 100, - EnableLogs: true, - Categories: []string{"web", "api", "backend"}, - Description: "Test server configuration", - }, - }, - }, - Order: []string{"test-custom-keys"}, - } - - // Write config - err := config.Write(testData) - require.NoError(t, err) - - // Verify the written file contains custom keys - content, err := os.ReadFile(testFile) - require.NoError(t, err) - - contentStr := string(content) - assert.Contains(t, contentStr, "server_name TestServer") - assert.Contains(t, contentStr, "max_connections 100") - assert.Contains(t, contentStr, "enable_logging true") - assert.Contains(t, contentStr, "category_list web,api,backend") - assert.Contains(t, contentStr, "description Test server configuration") - - // Read config back - readData, err := config.Parse(testFile) - require.NoError(t, err) - - // Verify all fields were correctly read back - readProps := readData.Sections["test-custom-keys"].Properties - assert.Equal(t, "TestServer", readProps.ServerName) - assert.Equal(t, 100, readProps.MaxConn) - assert.True(t, readProps.EnableLogs) - assert.Equal(t, []string{"web", "api", "backend"}, readProps.Categories) - assert.Equal(t, "Test server configuration", readProps.Description) - }) - - t.Run("Partial Fields", func(t *testing.T) { - testFile := filepath.Join(tempDir, utils.EncodePath("test-partial-keys")+".cfg") - - // Create config with only required and some optional fields - testData := &ConfigData[CustomKeyTestConfig]{ - FilePath: testFile, - Sections: map[string]*Section[CustomKeyTestConfig]{ - "test-partial-keys": { - Type: "server", - ID: "test-partial-keys", - Properties: CustomKeyTestConfig{ - ServerName: "MinimalServer", // Required field - MaxConn: 50, // Optional with custom key - }, - }, - }, - Order: []string{"test-partial-keys"}, - } - - // Write config - err := 
config.Write(testData) - require.NoError(t, err) - - // Read config back - readData, err := config.Parse(testFile) - require.NoError(t, err) - - // Verify fields - readProps := readData.Sections["test-partial-keys"].Properties - assert.Equal(t, "MinimalServer", readProps.ServerName) - assert.Equal(t, 50, readProps.MaxConn) - assert.False(t, readProps.EnableLogs) - assert.Empty(t, readProps.Categories) - assert.Empty(t, readProps.Description) - }) - - t.Run("Missing Required Field", func(t *testing.T) { - testData := &ConfigData[CustomKeyTestConfig]{ - Sections: map[string]*Section[CustomKeyTestConfig]{ - "test-missing-required": { - Type: "server", - ID: "test-missing-required", - Properties: CustomKeyTestConfig{ - // Missing ServerName which is required - MaxConn: 100, - }, - }, - }, - Order: []string{"test-missing-required"}, - } - - // Should fail validation - err := config.Write(testData) - assert.Error(t, err) - assert.Contains(t, err.Error(), "is empty") - }) -} +//go:build linux + +package config + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/sonroyaalmerol/pbs-plus/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test struct definitions +type BasicTestConfig struct { + Name string `config:"type=string,required"` + Value string `config:"type=string,required"` + Optional string `config:"type=string"` +} + +type ArrayTestConfig struct { + Tags []string `config:"type=array,required"` +} + +type ValidationTestConfig struct { + Email string `config:"type=string,required"` +} + +type CompatTestConfig struct { + Path string `config:"type=string,required"` + Comment string `config:"type=string"` + Count int `config:"type=int"` + Enabled bool `config:"type=bool"` + Tags []string `config:"type=array"` +} + +type CustomKeyTestConfig struct { + ServerName string `config:"type=string,required,key=server_name"` + MaxConn int `config:"type=int,key=max_connections"` + EnableLogs bool `config:"type=bool,key=enable_logging"` + Categories []string `config:"type=array,key=category_list"` + Description string `config:"type=string"` // No custom key - should use field name lowercase +} + +func TestSectionConfig_BasicOperations(t *testing.T) { + // Setup + tempDir := t.TempDir() + + // Create plugin with basic config + testPlugin := &SectionPlugin[BasicTestConfig]{ + TypeName: "test", + FolderPath: tempDir, + Validate: nil, + } + config := NewSectionConfig(testPlugin) + + t.Run("Create and Read", func(t *testing.T) { + testFile := filepath.Join(tempDir, utils.EncodePath("test-basic-cr")+".cfg") + testData := &ConfigData[BasicTestConfig]{ + Sections: map[string]*Section[BasicTestConfig]{ + "test-basic-cr": { + Type: "test", + ID: "test-basic-cr", + Properties: BasicTestConfig{ + Name: "Test 1", + Value: "Value 1", + }, + }, + }, + Order: []string{"test-basic-cr"}, + } + + // Write config + err := config.Write(testData) + require.NoError(t, err) + + // Read config + readData, err := config.Parse(testFile) + require.NoError(t, err) + + // Verify data + assert.Equal(t, testData.Order, readData.Order) + assert.Equal(t, testData.Sections["test-basic-cr"].Properties.Name, + readData.Sections["test-basic-cr"].Properties.Name) + }) + + t.Run("Missing Required Property", func(t *testing.T) { + testData := &ConfigData[BasicTestConfig]{ + Sections: map[string]*Section[BasicTestConfig]{ + "test-missing": { + Type: "test", + ID: "test-missing", + Properties: BasicTestConfig{ + Name: "Test 1", + // Missing required Value + }, + }, + }, + 
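
The `config:"..."` struct tags above are the entire schema the section parser sees. A minimal sketch of how such a tag can be decoded, assuming comma-separated options with `type=`, `key=`, and a bare `required` flag (the real parser in this package may differ in details):

package main

import (
	"fmt"
	"reflect"
	"strings"
)

type fieldSpec struct {
	Key      string // property name used on disk
	Type     string // string, int, bool, array
	Required bool
}

// parseConfigTags derives one fieldSpec per tagged struct field. The key
// defaults to the lowercased field name, which is the behavior the untagged
// Description field in CustomKeyTestConfig relies on.
func parseConfigTags(v any) map[string]fieldSpec {
	specs := make(map[string]fieldSpec)
	t := reflect.TypeOf(v)
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		tag, ok := f.Tag.Lookup("config")
		if !ok {
			continue
		}
		spec := fieldSpec{Key: strings.ToLower(f.Name)}
		for _, part := range strings.Split(tag, ",") {
			switch {
			case strings.HasPrefix(part, "type="):
				spec.Type = strings.TrimPrefix(part, "type=")
			case strings.HasPrefix(part, "key="):
				spec.Key = strings.TrimPrefix(part, "key=")
			case part == "required":
				spec.Required = true
			}
		}
		specs[f.Name] = spec
	}
	return specs
}

func main() {
	type demo struct {
		ServerName  string `config:"type=string,required,key=server_name"`
		Description string `config:"type=string"`
	}
	fmt.Printf("%+v\n", parseConfigTags(demo{}))
}
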
Order: []string{"test-missing"}, + } + + err := config.Write(testData) + assert.Error(t, err) + }) +} + +func TestSectionConfig_ArrayProperties(t *testing.T) { + // Setup + tempDir := t.TempDir() + + // Create plugin with array config + arrayPlugin := &SectionPlugin[ArrayTestConfig]{ + TypeName: "array-test", + FolderPath: tempDir, + Validate: nil, + } + config := NewSectionConfig(arrayPlugin) + + t.Run("Array Property Handling", func(t *testing.T) { + testFile := filepath.Join(tempDir, utils.EncodePath("test-array")+".cfg") + testData := &ConfigData[ArrayTestConfig]{ + Sections: map[string]*Section[ArrayTestConfig]{ + "test-array": { + Type: "array-test", + ID: "test-array", + Properties: ArrayTestConfig{ + Tags: []string{"tag1", "tag2", "tag3"}, + }, + }, + }, + Order: []string{"test-array"}, + } + + err := config.Write(testData) + require.NoError(t, err) + + readData, err := config.Parse(testFile) + require.NoError(t, err) + + assert.Equal(t, testData.Sections["test-array"].Properties.Tags, + readData.Sections["test-array"].Properties.Tags) + }) +} + +func TestSectionConfig_ValidationRules(t *testing.T) { + // Setup + tempDir := t.TempDir() + + // Create plugin with validation config + validationPlugin := &SectionPlugin[ValidationTestConfig]{ + TypeName: "validation-test", + FolderPath: tempDir, + Validate: func(c ValidationTestConfig) error { + if len(c.Email) > 254 { + return fmt.Errorf("email too long") + } + return nil + }, + } + config := NewSectionConfig(validationPlugin) + + t.Run("Valid Pattern", func(t *testing.T) { + testData := &ConfigData[ValidationTestConfig]{ + Sections: map[string]*Section[ValidationTestConfig]{ + "test-validate": { + Type: "validation-test", + ID: "test-validate", + Properties: ValidationTestConfig{ + Email: "test@example.com", + }, + }, + }, + Order: []string{"test-validate"}, + } + + err := config.Write(testData) + require.NoError(t, err) + }) + + t.Run("Email Too Long", func(t *testing.T) { + longEmail := "very-long-email" + for i := 0; i < 250; i++ { + longEmail += "x" + } + longEmail += "@example.com" + + testData := &ConfigData[ValidationTestConfig]{ + Sections: map[string]*Section[ValidationTestConfig]{ + "test-long-email": { + Type: "validation-test", + ID: "test-long-email", + Properties: ValidationTestConfig{ + Email: longEmail, + }, + }, + }, + Order: []string{"test-long-email"}, + } + + err := config.Write(testData) + assert.Error(t, err) + }) +} + +// Test edge cases from old format +func TestEdgeCaseCompatibility(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "edge.cfg") + + // Create a config file with edge cases + oldFormatConfig := `test: edge-case + path /test/path + comment + count + enabled false + tags + +` + err := os.WriteFile(configPath, []byte(oldFormatConfig), 0644) + require.NoError(t, err) + + plugin := &SectionPlugin[CompatTestConfig]{ + TypeName: "test", + FolderPath: tempDir, + } + config := NewSectionConfig(plugin) + + // Read and verify edge cases + configData, err := config.Parse(configPath) + require.NoError(t, err) + + section := configData.Sections["edge-case"] + props := section.Properties + + // Verify empty fields are handled correctly + require.Equal(t, "/test/path", props.Path) + require.Equal(t, "", props.Comment) + require.Equal(t, 0, props.Count) + require.False(t, props.Enabled) + require.Empty(t, props.Tags) + + // Write and verify it maintains format + err = config.Write(configData) + require.NoError(t, err) +} + +func TestFormatCompatibility(t *testing.T) { + tempDir := 
t.TempDir() + + tests := []struct { + name string + oldConfig string + expectedPath string + expectedCount int + expectedTags []string + expectedOutput string + }{ + { + name: "Tab Indentation", + oldConfig: `test: tab-indent + path /test/path + count 42 + tags tag1,tag2 + +`, + expectedPath: "/test/path", + expectedCount: 42, + expectedTags: []string{"tag1", "tag2"}, + expectedOutput: `test: tab-indent + count 42 + path /test/path + tags tag1,tag2 + +`, + }, + { + name: "Space Indentation", + oldConfig: `test: space-indent + path /test/path + count 42 + tags tag1,tag2 + +`, + expectedPath: "/test/path", + expectedCount: 42, + expectedTags: []string{"tag1", "tag2"}, + expectedOutput: `test: space-indent + count 42 + path /test/path + tags tag1,tag2 + +`, + }, + { + name: "Mixed Whitespace", + oldConfig: `test: mixed-ws + path /test/path + count 42 + tags tag1, tag2 + +`, + expectedPath: "/test/path", + expectedCount: 42, + expectedTags: []string{"tag1", "tag2"}, + expectedOutput: `test: mixed-ws + count 42 + path /test/path + tags tag1,tag2 + +`, + }, + { + name: "Multiple Sections", + oldConfig: `test: section1 + path /path1 + count 1 + +test: section2 + path /path2 + count 2 + +`, + expectedPath: "/path1", + expectedCount: 1, + expectedTags: nil, + expectedOutput: `test: section1 + count 1 + path /path1 + +test: section2 + count 2 + path /path2 + +`, + }, + { + name: "Empty Values", + oldConfig: `test: empty-values + path /test/path + count 0 + tags + enabled false + +`, + expectedPath: "/test/path", + expectedCount: 0, + expectedTags: nil, + expectedOutput: `test: empty-values + path /test/path + tags + +`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + plugin := &SectionPlugin[CompatTestConfig]{ + TypeName: "test", + FolderPath: tempDir, + } + config := NewSectionConfig(plugin) + + configPath := filepath.Join(tempDir, fmt.Sprintf("%s.cfg", tc.name)) + err := os.WriteFile(configPath, []byte(tc.oldConfig), 0644) + require.NoError(t, err) + + configData, err := config.Parse(configPath) + require.NoError(t, err) + + var firstSectionID string + for id := range configData.Sections { + firstSectionID = id + break + } + section := configData.Sections[firstSectionID] + require.NotNil(t, section) + + assert.Equal(t, tc.expectedPath, section.Properties.Path) + assert.Equal(t, tc.expectedCount, section.Properties.Count) + + if tc.expectedTags == nil { + assert.Empty(t, section.Properties.Tags) + } else { + assert.Equal(t, tc.expectedTags, section.Properties.Tags) + } + + err = config.Write(configData) + require.NoError(t, err) + + written, err := os.ReadFile(configPath) + require.NoError(t, err) + assert.Equal(t, tc.expectedOutput, string(written)) + + // Verify the written file can be parsed again + _, err = config.Parse(configPath) + require.NoError(t, err, "Written file should be parseable") + }) + } +} + +// TestCrossImplementationRoundTrip tests that configs can be written by old implementation +// and read by new, and vice versa +func TestCrossImplementationRoundTrip(t *testing.T) { + tempDir := t.TempDir() + err := os.MkdirAll(tempDir, 0750) + require.NoError(t, err) + + configPath := filepath.Join(tempDir, "roundtrip.cfg") + + testConfig := &ConfigData[CompatTestConfig]{ + FilePath: configPath, // Set the filepath explicitly + Sections: map[string]*Section[CompatTestConfig]{ + "test-section": { + Type: "test", + ID: "test-section", + Properties: CompatTestConfig{ + Path: "/complex/path with spaces", + Comment: "Multi word\tcomment with\ttabs", + Count: 42, 
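
The expected outputs in this table pin down the on-disk shape: a "type: id" header line, one tab-indented "key value" pair per property written in alphabetical key order, and a blank line terminating each section. Below is a writer sketch consistent with those expectations; which zero values survive serialization (the bare "tags" line persists while "count 0" and "enabled false" are dropped) is type-dependent in the real implementation and elided here:

package main

import (
	"fmt"
	"sort"
)

// writeSection emits one section in the format the compatibility tests
// above expect: header, sorted tab-indented properties, trailing blank line.
func writeSection(typeName, id string, props map[string]string) {
	fmt.Printf("%s: %s\n", typeName, id)
	keys := make([]string, 0, len(props))
	for k := range props {
		keys = append(keys, k)
	}
	sort.Strings(keys) // matches the key reordering seen in expectedOutput
	for _, k := range keys {
		fmt.Printf("\t%s %s\n", k, props[k])
	}
	fmt.Println() // blank line closes the section
}

func main() {
	writeSection("test", "tab-indent", map[string]string{
		"path":  "/test/path",
		"count": "42",
		"tags":  "tag1,tag2",
	})
}
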
+ Enabled: true, + Tags: []string{"tag1", "tag with space"}, + }, + }, + }, + Order: []string{"test-section"}, + } + + plugin := &SectionPlugin[CompatTestConfig]{ + TypeName: "test", + FolderPath: tempDir, + } + config := NewSectionConfig(plugin) + + // Write with new implementation + err = config.Write(testConfig) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(configPath) + require.NoError(t, err, "Config file should exist") + + // Read the config back + readConfig, err := config.Parse(configPath) + require.NoError(t, err) + + // Verify all fields match exactly + original := testConfig.Sections["test-section"].Properties + read := readConfig.Sections["test-section"].Properties + + assert.Equal(t, original.Path, read.Path) + assert.Equal(t, original.Comment, read.Comment) + assert.Equal(t, original.Count, read.Count) + assert.Equal(t, original.Enabled, read.Enabled) + assert.Equal(t, original.Tags, read.Tags) + + // Verify section order is preserved + assert.Equal(t, testConfig.Order, readConfig.Order) +} + +func TestCustomKeyConfig(t *testing.T) { + tempDir := t.TempDir() + + plugin := &SectionPlugin[CustomKeyTestConfig]{ + TypeName: "server", + FolderPath: tempDir, + } + config := NewSectionConfig(plugin) + + t.Run("Write and Read Custom Keys", func(t *testing.T) { + testFile := filepath.Join(tempDir, utils.EncodePath("test-custom-keys")+".cfg") + + // Create test data with all fields populated + testData := &ConfigData[CustomKeyTestConfig]{ + FilePath: testFile, + Sections: map[string]*Section[CustomKeyTestConfig]{ + "test-custom-keys": { + Type: "server", + ID: "test-custom-keys", + Properties: CustomKeyTestConfig{ + ServerName: "TestServer", + MaxConn: 100, + EnableLogs: true, + Categories: []string{"web", "api", "backend"}, + Description: "Test server configuration", + }, + }, + }, + Order: []string{"test-custom-keys"}, + } + + // Write config + err := config.Write(testData) + require.NoError(t, err) + + // Verify the written file contains custom keys + content, err := os.ReadFile(testFile) + require.NoError(t, err) + + contentStr := string(content) + assert.Contains(t, contentStr, "server_name TestServer") + assert.Contains(t, contentStr, "max_connections 100") + assert.Contains(t, contentStr, "enable_logging true") + assert.Contains(t, contentStr, "category_list web,api,backend") + assert.Contains(t, contentStr, "description Test server configuration") + + // Read config back + readData, err := config.Parse(testFile) + require.NoError(t, err) + + // Verify all fields were correctly read back + readProps := readData.Sections["test-custom-keys"].Properties + assert.Equal(t, "TestServer", readProps.ServerName) + assert.Equal(t, 100, readProps.MaxConn) + assert.True(t, readProps.EnableLogs) + assert.Equal(t, []string{"web", "api", "backend"}, readProps.Categories) + assert.Equal(t, "Test server configuration", readProps.Description) + }) + + t.Run("Partial Fields", func(t *testing.T) { + testFile := filepath.Join(tempDir, utils.EncodePath("test-partial-keys")+".cfg") + + // Create config with only required and some optional fields + testData := &ConfigData[CustomKeyTestConfig]{ + FilePath: testFile, + Sections: map[string]*Section[CustomKeyTestConfig]{ + "test-partial-keys": { + Type: "server", + ID: "test-partial-keys", + Properties: CustomKeyTestConfig{ + ServerName: "MinimalServer", // Required field + MaxConn: 50, // Optional with custom key + }, + }, + }, + Order: []string{"test-partial-keys"}, + } + + // Write config + err := config.Write(testData) + 
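
Because round-tripped values can legitimately contain spaces and tabs ("Multi word\tcomment with\ttabs", "tag with space"), only the first whitespace run on a property line can serve as the key/value separator. A sketch of that split, under the assumption that values never carry leading whitespace of their own:

package main

import (
	"fmt"
	"strings"
)

// parseProperty splits an indented "key value" line on the first
// whitespace run; everything after it, tabs included, belongs to the value.
func parseProperty(line string) (key, value string) {
	line = strings.TrimLeft(line, " \t") // drop the indentation
	if i := strings.IndexAny(line, " \t"); i >= 0 {
		return line[:i], strings.TrimLeft(line[i:], " \t")
	}
	return line, "" // bare key, e.g. an empty "tags" property
}

func main() {
	k, v := parseProperty("\tcomment Multi word\tcomment with\ttabs")
	fmt.Printf("%q -> %q\n", k, v)
}
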
require.NoError(t, err) + + // Read config back + readData, err := config.Parse(testFile) + require.NoError(t, err) + + // Verify fields + readProps := readData.Sections["test-partial-keys"].Properties + assert.Equal(t, "MinimalServer", readProps.ServerName) + assert.Equal(t, 50, readProps.MaxConn) + assert.False(t, readProps.EnableLogs) + assert.Empty(t, readProps.Categories) + assert.Empty(t, readProps.Description) + }) + + t.Run("Missing Required Field", func(t *testing.T) { + testData := &ConfigData[CustomKeyTestConfig]{ + Sections: map[string]*Section[CustomKeyTestConfig]{ + "test-missing-required": { + Type: "server", + ID: "test-missing-required", + Properties: CustomKeyTestConfig{ + // Missing ServerName which is required + MaxConn: 100, + }, + }, + }, + Order: []string{"test-missing-required"}, + } + + // Should fail validation + err := config.Write(testData) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is empty") + }) +} diff --git a/internal/proxy/controllers/agents/agents.go b/internal/proxy/controllers/agents/agents.go index c6adde9..667c05a 100644 --- a/internal/proxy/controllers/agents/agents.go +++ b/internal/proxy/controllers/agents/agents.go @@ -1,241 +1,241 @@ -//go:build linux - -package agents - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "strings" - - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -type LogRequest struct { - Hostname string `json:"hostname"` - Message string `json:"message"` - Level string `json:"level"` -} - -func AgentLogHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - } - - var reqParsed LogRequest - err := json.NewDecoder(r.Body).Decode(&reqParsed) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - switch reqParsed.Level { - case "info": - syslog.L.Infof("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) - case "error": - syslog.L.Errorf("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) - case "warn": - syslog.L.Warnf("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) - default: - syslog.L.Infof("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) - } - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(map[string]string{"success": "true"}) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - } -} - -type BootstrapRequest struct { - Hostname string `json:"hostname"` - CSR string `json:"csr"` - Drives []utils.DriveInfo `json:"drives"` -} - -func AgentBootstrapHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - } - - authHeader := r.Header.Get("Authorization") - authHeaderSplit := strings.Split(authHeader, " ") - if len(authHeaderSplit) != 2 || authHeaderSplit[0] != "Bearer" { - w.WriteHeader(http.StatusUnauthorized) - controllers.WriteErrorResponse(w, fmt.Errorf("unauthorized bearer access: %s", authHeader)) - return - } - - tokenStr := 
authHeaderSplit[1] - token, err := storeInstance.Database.GetToken(tokenStr) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - controllers.WriteErrorResponse(w, fmt.Errorf("token not found")) - return - } - - if token.Revoked { - w.WriteHeader(http.StatusUnauthorized) - controllers.WriteErrorResponse(w, fmt.Errorf("token already revoked")) - return - } - - var reqParsed BootstrapRequest - err = json.NewDecoder(r.Body).Decode(&reqParsed) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - decodedCSR, err := base64.StdEncoding.DecodeString(reqParsed.CSR) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - cert, err := storeInstance.CertGenerator.SignCSR(decodedCSR) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - encodedCert := base64.StdEncoding.EncodeToString(cert) - encodedCA := base64.StdEncoding.EncodeToString(storeInstance.CertGenerator.GetCAPEM()) - - clientIP := r.RemoteAddr - - forwarded := r.Header.Get("X-FORWARDED-FOR") - if forwarded != "" { - clientIP = forwarded - } - - clientIP = strings.Split(clientIP, ":")[0] - - for _, drive := range reqParsed.Drives { - newTarget := types.Target{ - Name: fmt.Sprintf("%s - %s", reqParsed.Hostname, drive.Letter), - Path: fmt.Sprintf("agent://%s/%s", clientIP, drive.Letter), - Auth: encodedCert, - TokenUsed: tokenStr, - DriveType: drive.Type, - DriveFS: drive.FileSystem, - DriveFreeBytes: int(drive.FreeBytes), - DriveUsedBytes: int(drive.UsedBytes), - DriveTotalBytes: int(drive.TotalBytes), - DriveFree: drive.Free, - DriveUsed: drive.Used, - DriveTotal: drive.Total, - } - - err := storeInstance.Database.CreateTarget(newTarget) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - } - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(map[string]string{"ca": encodedCA, "cert": encodedCert}) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - } -} - -func AgentRenewHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - } - - var reqParsed BootstrapRequest - err := json.NewDecoder(r.Body).Decode(&reqParsed) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - existingTarget, err := storeInstance.Database.GetTarget(reqParsed.Hostname + " - C") - if err != nil || existingTarget == nil { - w.WriteHeader(http.StatusNotFound) - controllers.WriteErrorResponse(w, err) - return - } - - decodedCSR, err := base64.StdEncoding.DecodeString(reqParsed.CSR) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - cert, err := storeInstance.CertGenerator.SignCSR(decodedCSR) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - encodedCert := base64.StdEncoding.EncodeToString(cert) - encodedCA := base64.StdEncoding.EncodeToString(storeInstance.CertGenerator.GetCAPEM()) - - clientIP := r.RemoteAddr - - forwarded := r.Header.Get("X-FORWARDED-FOR") - if forwarded != "" { - clientIP = forwarded - } - - clientIP = 
strings.Split(clientIP, ":")[0] - - for _, drive := range reqParsed.Drives { - newTarget := types.Target{ - Name: fmt.Sprintf("%s - %s", reqParsed.Hostname, drive.Letter), - Path: fmt.Sprintf("agent://%s/%s", clientIP, drive.Letter), - Auth: encodedCert, - TokenUsed: existingTarget.TokenUsed, - DriveType: drive.Type, - DriveFS: drive.FileSystem, - DriveFreeBytes: int(drive.FreeBytes), - DriveUsedBytes: int(drive.UsedBytes), - DriveTotalBytes: int(drive.TotalBytes), - DriveFree: drive.Free, - DriveUsed: drive.Used, - DriveTotal: drive.Total, - } - - err := storeInstance.Database.CreateTarget(newTarget) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - } - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(map[string]string{"ca": encodedCA, "cert": encodedCert}) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - } -} +//go:build linux + +package agents + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +type LogRequest struct { + Hostname string `json:"hostname"` + Message string `json:"message"` + Level string `json:"level"` +} + +func AgentLogHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + } + + var reqParsed LogRequest + err := json.NewDecoder(r.Body).Decode(&reqParsed) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + switch reqParsed.Level { + case "info": + syslog.L.Infof("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) + case "error": + syslog.L.Errorf("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) + case "warn": + syslog.L.Warnf("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) + default: + syslog.L.Infof("PBS Agent [%s]: %s", reqParsed.Hostname, reqParsed.Message) + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(map[string]string{"success": "true"}) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + } +} + +type BootstrapRequest struct { + Hostname string `json:"hostname"` + CSR string `json:"csr"` + Drives []utils.DriveInfo `json:"drives"` +} + +func AgentBootstrapHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + } + + authHeader := r.Header.Get("Authorization") + authHeaderSplit := strings.Split(authHeader, " ") + if len(authHeaderSplit) != 2 || authHeaderSplit[0] != "Bearer" { + w.WriteHeader(http.StatusUnauthorized) + controllers.WriteErrorResponse(w, fmt.Errorf("unauthorized bearer access: %s", authHeader)) + return + } + + tokenStr := authHeaderSplit[1] + token, err := storeInstance.Database.GetToken(tokenStr) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + controllers.WriteErrorResponse(w, fmt.Errorf("token not found")) + return 
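
One thing worth flagging in the agent handlers in this file (on both sides of the diff): the method guards call http.Error but never return, so a GET against AgentLogHandler, AgentBootstrapHandler, or AgentRenewHandler writes a 400 and then keeps executing the body anyway. A guard that actually stops the handler needs an explicit return:

if r.Method != http.MethodPost {
	http.Error(w, "Invalid HTTP method", http.StatusBadRequest)
	return // without this, decoding and the rest of the handler still run
}
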
+ } + + if token.Revoked { + w.WriteHeader(http.StatusUnauthorized) + controllers.WriteErrorResponse(w, fmt.Errorf("token already revoked")) + return + } + + var reqParsed BootstrapRequest + err = json.NewDecoder(r.Body).Decode(&reqParsed) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + decodedCSR, err := base64.StdEncoding.DecodeString(reqParsed.CSR) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + cert, err := storeInstance.CertGenerator.SignCSR(decodedCSR) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + encodedCert := base64.StdEncoding.EncodeToString(cert) + encodedCA := base64.StdEncoding.EncodeToString(storeInstance.CertGenerator.GetCAPEM()) + + clientIP := r.RemoteAddr + + forwarded := r.Header.Get("X-FORWARDED-FOR") + if forwarded != "" { + clientIP = forwarded + } + + clientIP = strings.Split(clientIP, ":")[0] + + for _, drive := range reqParsed.Drives { + newTarget := types.Target{ + Name: fmt.Sprintf("%s - %s", reqParsed.Hostname, drive.Letter), + Path: fmt.Sprintf("agent://%s/%s", clientIP, drive.Letter), + Auth: encodedCert, + TokenUsed: tokenStr, + DriveType: drive.Type, + DriveFS: drive.FileSystem, + DriveFreeBytes: int(drive.FreeBytes), + DriveUsedBytes: int(drive.UsedBytes), + DriveTotalBytes: int(drive.TotalBytes), + DriveFree: drive.Free, + DriveUsed: drive.Used, + DriveTotal: drive.Total, + } + + err := storeInstance.Database.CreateTarget(newTarget) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(map[string]string{"ca": encodedCA, "cert": encodedCert}) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + } +} + +func AgentRenewHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + } + + var reqParsed BootstrapRequest + err := json.NewDecoder(r.Body).Decode(&reqParsed) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + existingTarget, err := storeInstance.Database.GetTarget(reqParsed.Hostname + " - C") + if err != nil || existingTarget == nil { + w.WriteHeader(http.StatusNotFound) + controllers.WriteErrorResponse(w, err) + return + } + + decodedCSR, err := base64.StdEncoding.DecodeString(reqParsed.CSR) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + cert, err := storeInstance.CertGenerator.SignCSR(decodedCSR) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + encodedCert := base64.StdEncoding.EncodeToString(cert) + encodedCA := base64.StdEncoding.EncodeToString(storeInstance.CertGenerator.GetCAPEM()) + + clientIP := r.RemoteAddr + + forwarded := r.Header.Get("X-FORWARDED-FOR") + if forwarded != "" { + clientIP = forwarded + } + + clientIP = strings.Split(clientIP, ":")[0] + + for _, drive := range reqParsed.Drives { + newTarget := types.Target{ + Name: fmt.Sprintf("%s - %s", reqParsed.Hostname, drive.Letter), + Path: fmt.Sprintf("agent://%s/%s", 
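
From the agent's perspective, the bootstrap exchange above is: POST a JSON body carrying the hostname, a base64-encoded CSR, and the drive inventory under a bearer token, then receive the signed certificate and CA back as base64. A client-side sketch; the /plus/agent/bootstrap route string is an assumption, not taken from this diff:

package agentclient

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
)

// bootstrapAgent mirrors BootstrapRequest on the wire. drives is left
// empty here; a real agent would enumerate its volumes first.
func bootstrapAgent(server, token, hostname string, csrDER []byte) (ca, cert string, err error) {
	payload, err := json.Marshal(map[string]any{
		"hostname": hostname,
		"csr":      base64.StdEncoding.EncodeToString(csrDER),
		"drives":   []any{},
	})
	if err != nil {
		return "", "", err
	}
	req, err := http.NewRequest(http.MethodPost,
		server+"/plus/agent/bootstrap", bytes.NewReader(payload))
	if err != nil {
		return "", "", err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", "", fmt.Errorf("bootstrap failed: %s", resp.Status)
	}
	var out struct {
		CA   string `json:"ca"`
		Cert string `json:"cert"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return "", "", err
	}
	return out.CA, out.Cert, nil
}
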
clientIP, drive.Letter), + Auth: encodedCert, + TokenUsed: existingTarget.TokenUsed, + DriveType: drive.Type, + DriveFS: drive.FileSystem, + DriveFreeBytes: int(drive.FreeBytes), + DriveUsedBytes: int(drive.UsedBytes), + DriveTotalBytes: int(drive.TotalBytes), + DriveFree: drive.Free, + DriveUsed: drive.Used, + DriveTotal: drive.Total, + } + + err := storeInstance.Database.CreateTarget(newTarget) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(map[string]string{"ca": encodedCA, "cert": encodedCert}) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + } +} diff --git a/internal/proxy/controllers/exclusions/exclusions.go b/internal/proxy/controllers/exclusions/exclusions.go index 3ba966a..2d226b7 100644 --- a/internal/proxy/controllers/exclusions/exclusions.go +++ b/internal/proxy/controllers/exclusions/exclusions.go @@ -1,180 +1,180 @@ -//go:build linux - -package exclusions - -import ( - "encoding/json" - "net/http" - "net/url" - - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func D2DExclusionHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet && r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - if r.Method == http.MethodGet { - all, err := storeInstance.Database.GetAllGlobalExclusions() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - digest, err := utils.CalculateDigest(all) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - toReturn := ExclusionsResponse{ - Data: all, - Digest: digest, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(toReturn) - - return - } - } -} - -func ExtJsExclusionHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := ExclusionConfigResponse{} - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - err := r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - newExclusion := types.Exclusion{ - Path: r.FormValue("path"), - Comment: r.FormValue("comment"), - } - - err = storeInstance.Database.CreateExclusion(newExclusion) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - } -} - -func ExtJsExclusionSingleHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := ExclusionConfigResponse{} - if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - } - - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPut { - err := r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - pathDecoded, err := 
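
A caveat on the client-IP derivation used by both the bootstrap and renew handlers: strings.Split(clientIP, ":")[0] assumes an IPv4 host:port form, so an IPv6 RemoteAddr such as "[::1]:8007" collapses to "[". X-FORWARDED-FOR can also carry a comma-separated chain where only the first hop is the client. A more robust sketch:

host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
	host = r.RemoteAddr // no port component to strip
}
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
	host = strings.TrimSpace(strings.Split(fwd, ",")[0])
}
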
url.QueryUnescape(utils.DecodePath(r.PathValue("exclusion"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - exclusion, err := storeInstance.Database.GetExclusion(pathDecoded) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - if r.FormValue("path") != "" { - exclusion.Path = r.FormValue("path") - } - if r.FormValue("comment") != "" { - exclusion.Comment = r.FormValue("comment") - } - - if delArr, ok := r.Form["delete"]; ok { - for _, attr := range delArr { - switch attr { - case "path": - exclusion.Path = "" - case "comment": - exclusion.Comment = "" - } - } - } - - err = storeInstance.Database.UpdateExclusion(*exclusion) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodGet { - pathDecoded, err := url.QueryUnescape(utils.DecodePath(r.PathValue("exclusion"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - exclusion, err := storeInstance.Database.GetExclusion(pathDecoded) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - response.Data = exclusion - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodDelete { - pathDecoded, err := url.QueryUnescape(utils.DecodePath(r.PathValue("exclusion"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - err = storeInstance.Database.DeleteExclusion(pathDecoded) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - return - } - } -} +//go:build linux + +package exclusions + +import ( + "encoding/json" + "net/http" + "net/url" + + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func D2DExclusionHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + if r.Method == http.MethodGet { + all, err := storeInstance.Database.GetAllGlobalExclusions() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + digest, err := utils.CalculateDigest(all) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + toReturn := ExclusionsResponse{ + Data: all, + Digest: digest, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toReturn) + + return + } + } +} + +func ExtJsExclusionHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := ExclusionConfigResponse{} + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + err := r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + newExclusion := types.Exclusion{ + Path: r.FormValue("path"), + Comment: r.FormValue("comment"), + } + + err = storeInstance.Database.CreateExclusion(newExclusion) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + 
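
utils.CalculateDigest is what lets the ExtJS frontend detect stale lists without diffing the payload itself. Its definition is not part of this diff; a plausible implementation, assuming it hashes the canonical JSON encoding, would be:

// Sketch only: the real CalculateDigest may use a different encoding or hash.
// (imports: crypto/sha256, encoding/hex, encoding/json)
func CalculateDigest(v any) (string, error) {
	b, err := json.Marshal(v)
	if err != nil {
		return "", err
	}
	sum := sha256.Sum256(b)
	return hex.EncodeToString(sum[:]), nil
}
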
response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + } +} + +func ExtJsExclusionSingleHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := ExclusionConfigResponse{} + if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + } + + w.Header().Set("Content-Type", "application/json") + + if r.Method == http.MethodPut { + err := r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + pathDecoded, err := url.QueryUnescape(utils.DecodePath(r.PathValue("exclusion"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + exclusion, err := storeInstance.Database.GetExclusion(pathDecoded) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + if r.FormValue("path") != "" { + exclusion.Path = r.FormValue("path") + } + if r.FormValue("comment") != "" { + exclusion.Comment = r.FormValue("comment") + } + + if delArr, ok := r.Form["delete"]; ok { + for _, attr := range delArr { + switch attr { + case "path": + exclusion.Path = "" + case "comment": + exclusion.Comment = "" + } + } + } + + err = storeInstance.Database.UpdateExclusion(*exclusion) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodGet { + pathDecoded, err := url.QueryUnescape(utils.DecodePath(r.PathValue("exclusion"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + exclusion, err := storeInstance.Database.GetExclusion(pathDecoded) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + response.Data = exclusion + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodDelete { + pathDecoded, err := url.QueryUnescape(utils.DecodePath(r.PathValue("exclusion"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + err = storeInstance.Database.DeleteExclusion(pathDecoded) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + return + } + } +} diff --git a/internal/proxy/controllers/exclusions/types.go b/internal/proxy/controllers/exclusions/types.go index e77e9c5..6321beb 100644 --- a/internal/proxy/controllers/exclusions/types.go +++ b/internal/proxy/controllers/exclusions/types.go @@ -1,20 +1,20 @@ -//go:build linux - -package exclusions - -import ( - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" -) - -type ExclusionsResponse struct { - Data []types.Exclusion `json:"data"` - Digest string `json:"digest"` -} - -type ExclusionConfigResponse struct { - Errors map[string]string `json:"errors"` - Message string `json:"message"` - Data *types.Exclusion `json:"data"` - Status int `json:"status"` - Success bool `json:"success"` -} +//go:build linux + +package exclusions + +import ( + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" +) + +type ExclusionsResponse struct { + Data []types.Exclusion `json:"data"` + Digest string `json:"digest"` +} + +type ExclusionConfigResponse struct { + Errors map[string]string `json:"errors"` + Message string `json:"message"` + Data *types.Exclusion `json:"data"` + Status int 
`json:"status"` + Success bool `json:"success"` +} diff --git a/internal/proxy/controllers/jobs/jobs.go b/internal/proxy/controllers/jobs/jobs.go index c008382..118d83d 100644 --- a/internal/proxy/controllers/jobs/jobs.go +++ b/internal/proxy/controllers/jobs/jobs.go @@ -1,260 +1,260 @@ -//go:build linux - -package jobs - -import ( - "encoding/json" - "net/http" - "strings" - - "github.com/sonroyaalmerol/pbs-plus/internal/backend/backup" - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func D2DJobHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - allJobs, err := storeInstance.Database.GetAllJobs() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - digest, err := utils.CalculateDigest(allJobs) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - toReturn := JobsResponse{ - Data: allJobs, - Digest: digest, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(toReturn) - } -} - -func ExtJsJobRunHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := JobRunResponse{} - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - job, err := storeInstance.Database.GetJob(utils.DecodePath(r.PathValue("job"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - op, err := backup.RunBackup(job, storeInstance, false) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - task := op.Task - - w.Header().Set("Content-Type", "application/json") - - response.Data = task.UPID - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - } -} - -func ExtJsJobHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := JobConfigResponse{} - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - err := r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - newJob := types.Job{ - ID: r.FormValue("id"), - Store: r.FormValue("store"), - Target: r.FormValue("target"), - Subpath: r.FormValue("subpath"), - Schedule: r.FormValue("schedule"), - Comment: r.FormValue("comment"), - Namespace: r.FormValue("ns"), - NotificationMode: r.FormValue("notification-mode"), - Exclusions: []types.Exclusion{}, - } - - rawExclusions := r.FormValue("rawexclusions") - for _, exclusion := range strings.Split(rawExclusions, "\n") { - exclusion = strings.TrimSpace(exclusion) - if exclusion == "" { - continue - } - - exclusionInst := types.Exclusion{ - Path: exclusion, - JobID: newJob.ID, - } - - newJob.Exclusions = append(newJob.Exclusions, exclusionInst) - } - - err = storeInstance.Database.CreateJob(newJob) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - } -} - -func ExtJsJobSingleHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w 
http.ResponseWriter, r *http.Request) { - response := JobConfigResponse{} - if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPut { - job, err := storeInstance.Database.GetJob(utils.DecodePath(r.PathValue("job"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - err = r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - if r.FormValue("store") != "" { - job.Store = r.FormValue("store") - } - if r.FormValue("target") != "" { - job.Target = r.FormValue("target") - } - if r.FormValue("subpath") != "" { - job.Subpath = r.FormValue("subpath") - } - if r.FormValue("schedule") != "" { - job.Schedule = r.FormValue("schedule") - } - if r.FormValue("comment") != "" { - job.Comment = r.FormValue("comment") - } - if r.FormValue("ns") != "" { - job.Namespace = r.FormValue("ns") - } - if r.FormValue("notification-mode") != "" { - job.NotificationMode = r.FormValue("notification-mode") - } - - if r.FormValue("rawexclusions") != "" { - job.Exclusions = []types.Exclusion{} - - rawExclusions := r.FormValue("rawexclusions") - for _, exclusion := range strings.Split(rawExclusions, "\n") { - exclusion = strings.TrimSpace(exclusion) - if exclusion == "" { - continue - } - - exclusionInst := types.Exclusion{ - Path: exclusion, - JobID: job.ID, - } - - job.Exclusions = append(job.Exclusions, exclusionInst) - } - } - - if delArr, ok := r.Form["delete"]; ok { - for _, attr := range delArr { - switch attr { - case "store": - job.Store = "" - case "target": - job.Target = "" - case "subpath": - job.Subpath = "" - case "schedule": - job.Schedule = "" - case "comment": - job.Comment = "" - case "ns": - job.Namespace = "" - case "notification-mode": - job.NotificationMode = "" - case "rawexclusions": - job.Exclusions = []types.Exclusion{} - } - } - } - - err = storeInstance.Database.UpdateJob(*job) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodGet { - job, err := storeInstance.Database.GetJob(utils.DecodePath(r.PathValue("job"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - response.Data = job - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodDelete { - err := storeInstance.Database.DeleteJob(utils.DecodePath(r.PathValue("job"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - return - } - } -} +//go:build linux + +package jobs + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/sonroyaalmerol/pbs-plus/internal/backend/backup" + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func D2DJobHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + allJobs, err := 
storeInstance.Database.GetAllJobs() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + digest, err := utils.CalculateDigest(allJobs) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + toReturn := JobsResponse{ + Data: allJobs, + Digest: digest, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toReturn) + } +} + +func ExtJsJobRunHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := JobRunResponse{} + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + job, err := storeInstance.Database.GetJob(utils.DecodePath(r.PathValue("job"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + op, err := backup.RunBackup(job, storeInstance, false) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + task := op.Task + + w.Header().Set("Content-Type", "application/json") + + response.Data = task.UPID + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + } +} + +func ExtJsJobHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := JobConfigResponse{} + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + err := r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + newJob := types.Job{ + ID: r.FormValue("id"), + Store: r.FormValue("store"), + Target: r.FormValue("target"), + Subpath: r.FormValue("subpath"), + Schedule: r.FormValue("schedule"), + Comment: r.FormValue("comment"), + Namespace: r.FormValue("ns"), + NotificationMode: r.FormValue("notification-mode"), + Exclusions: []types.Exclusion{}, + } + + rawExclusions := r.FormValue("rawexclusions") + for _, exclusion := range strings.Split(rawExclusions, "\n") { + exclusion = strings.TrimSpace(exclusion) + if exclusion == "" { + continue + } + + exclusionInst := types.Exclusion{ + Path: exclusion, + JobID: newJob.ID, + } + + newJob.Exclusions = append(newJob.Exclusions, exclusionInst) + } + + err = storeInstance.Database.CreateJob(newJob) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + } +} + +func ExtJsJobSingleHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := JobConfigResponse{} + if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + if r.Method == http.MethodPut { + job, err := storeInstance.Database.GetJob(utils.DecodePath(r.PathValue("job"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + err = r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + if r.FormValue("store") != "" { + job.Store = r.FormValue("store") + } + if r.FormValue("target") != "" { + job.Target = r.FormValue("target") + } + if r.FormValue("subpath") != "" { + job.Subpath = r.FormValue("subpath") + } + if r.FormValue("schedule") != "" { + job.Schedule = r.FormValue("schedule") + } + if r.FormValue("comment") != "" { + job.Comment 
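
The PUT branch here accepts application/x-www-form-urlencoded updates where any non-empty field overwrites the stored job and repeated "delete" fields name properties to clear. A test-shaped sketch of driving it with httptest; the route string and the nil store stand-in are assumptions (a real test needs a wired store.Store), and Request.SetPathValue (Go 1.22+) substitutes for the router's {job} segment:

package jobs

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/sonroyaalmerol/pbs-plus/internal/store"
)

// newTestStore is a placeholder; constructing a real store.Store with a
// backing database is outside the scope of this sketch.
func newTestStore(t *testing.T) *store.Store {
	t.Helper()
	return nil
}

func TestUpdateJobForm(t *testing.T) {
	t.Skip("sketch only: requires a fully wired store.Store")

	form := url.Values{}
	form.Set("schedule", "daily") // non-empty fields overwrite
	form.Add("delete", "comment") // "delete" entries blank properties out

	req := httptest.NewRequest(http.MethodPut, "/api2/extjs/d2d-job/job-1",
		strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetPathValue("job", "job-1") // assumes DecodePath passes plain IDs through

	rec := httptest.NewRecorder()
	ExtJsJobSingleHandler(newTestStore(t))(rec, req)
	if rec.Code != http.StatusOK {
		t.Fatalf("unexpected status %d: %s", rec.Code, rec.Body.String())
	}
}
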
= r.FormValue("comment") + } + if r.FormValue("ns") != "" { + job.Namespace = r.FormValue("ns") + } + if r.FormValue("notification-mode") != "" { + job.NotificationMode = r.FormValue("notification-mode") + } + + if r.FormValue("rawexclusions") != "" { + job.Exclusions = []types.Exclusion{} + + rawExclusions := r.FormValue("rawexclusions") + for _, exclusion := range strings.Split(rawExclusions, "\n") { + exclusion = strings.TrimSpace(exclusion) + if exclusion == "" { + continue + } + + exclusionInst := types.Exclusion{ + Path: exclusion, + JobID: job.ID, + } + + job.Exclusions = append(job.Exclusions, exclusionInst) + } + } + + if delArr, ok := r.Form["delete"]; ok { + for _, attr := range delArr { + switch attr { + case "store": + job.Store = "" + case "target": + job.Target = "" + case "subpath": + job.Subpath = "" + case "schedule": + job.Schedule = "" + case "comment": + job.Comment = "" + case "ns": + job.Namespace = "" + case "notification-mode": + job.NotificationMode = "" + case "rawexclusions": + job.Exclusions = []types.Exclusion{} + } + } + } + + err = storeInstance.Database.UpdateJob(*job) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodGet { + job, err := storeInstance.Database.GetJob(utils.DecodePath(r.PathValue("job"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + response.Data = job + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodDelete { + err := storeInstance.Database.DeleteJob(utils.DecodePath(r.PathValue("job"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + return + } + } +} diff --git a/internal/proxy/controllers/jobs/types.go b/internal/proxy/controllers/jobs/types.go index 3f54553..8b620ce 100644 --- a/internal/proxy/controllers/jobs/types.go +++ b/internal/proxy/controllers/jobs/types.go @@ -1,28 +1,28 @@ -//go:build linux - -package jobs - -import ( - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" -) - -type JobsResponse struct { - Data []types.Job `json:"data"` - Digest string `json:"digest"` -} - -type JobConfigResponse struct { - Errors map[string]string `json:"errors"` - Message string `json:"message"` - Data *types.Job `json:"data"` - Status int `json:"status"` - Success bool `json:"success"` -} - -type JobRunResponse struct { - Errors map[string]string `json:"errors"` - Message string `json:"message"` - Data string `json:"data"` - Status int `json:"status"` - Success bool `json:"success"` -} +//go:build linux + +package jobs + +import ( + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" +) + +type JobsResponse struct { + Data []types.Job `json:"data"` + Digest string `json:"digest"` +} + +type JobConfigResponse struct { + Errors map[string]string `json:"errors"` + Message string `json:"message"` + Data *types.Job `json:"data"` + Status int `json:"status"` + Success bool `json:"success"` +} + +type JobRunResponse struct { + Errors map[string]string `json:"errors"` + Message string `json:"message"` + Data string `json:"data"` + Status int `json:"status"` + Success bool `json:"success"` +} diff --git a/internal/proxy/controllers/plus/plus.go b/internal/proxy/controllers/plus/plus.go index 0dcec92..e8ac5eb 100644 --- 
a/internal/proxy/controllers/plus/plus.go +++ b/internal/proxy/controllers/plus/plus.go @@ -1,207 +1,207 @@ -//go:build linux - -package plus - -import ( - "context" - "encoding/base32" - "encoding/json" - "fmt" - "io" - "net/http" - "time" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" - "github.com/sonroyaalmerol/pbs-plus/internal/websockets" -) - -func MountHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost && r.Method != http.MethodDelete { - http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) - return - } - - // TODO: add check for security - - targetHostnameEnc := utils.DecodePath(r.PathValue("target")) - agentDriveEnc := utils.DecodePath(r.PathValue("drive")) - - targetHostnameBytes, err := base32.StdEncoding.DecodeString(targetHostnameEnc) - if err != nil { - http.Error(w, "invalid arguments", http.StatusBadRequest) - return - } - - agentDriveBytes, err := base32.StdEncoding.DecodeString(agentDriveEnc) - if err != nil { - http.Error(w, "invalid arguments", http.StatusBadRequest) - return - } - - targetHostname := string(targetHostnameBytes) - agentDrive := string(agentDriveBytes) - - if r.Method == http.MethodPost { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - // Create response channel and register handler - respChan := make(chan *websockets.Message, 1) - cleanup := storeInstance.WSHub.RegisterHandler("response-backup_start", func(handlerCtx context.Context, msg *websockets.Message) error { - if msg.Content == "Acknowledged: "+agentDrive { - respChan <- msg - } - return nil - }) - defer cleanup() - - // Send initial message - err := storeInstance.WSHub.SendToClient(targetHostname, websockets.Message{ - Type: "backup_start", - Content: agentDrive, - }) - if err != nil { - http.Error(w, fmt.Sprintf("MountHandler: Failed to send backup request to target -> %v", err), http.StatusInternalServerError) - return - } - - // Wait for either response or timeout - select { - case <-respChan: - case <-ctx.Done(): - http.Error(w, "MountHandler: Timeout waiting for backup acknowledgement from target", http.StatusInternalServerError) - return - } - - // Handle successful response - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]string{"status": "true"}); err != nil { - http.Error(w, fmt.Sprintf("MountHandler: Failed to encode response -> %v", err), http.StatusInternalServerError) - return - } - } - - if r.Method == http.MethodDelete { - _ = storeInstance.WSHub.SendToClient(targetHostname, websockets.Message{ - Type: "backup_close", - Content: agentDrive, - }) - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]string{"status": "true"}) - - return - } - } -} - -func VersionHandler(storeInstance *store.Store, version string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) - return - } - - toReturn := VersionResponse{ - Version: version, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(toReturn) - } -} - -func DownloadBinary(storeInstance *store.Store, version string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", 
http.StatusMethodNotAllowed) - return - } - - // Construct the passthrough URL - baseURL := "https://github.com/sonroyaalmerol/pbs-plus/releases/download/" - targetURL := fmt.Sprintf("%s%s/pbs-plus-agent-%s-windows-amd64.exe", baseURL, version, version) - - // Proxy the request - req, err := http.NewRequest(http.MethodGet, targetURL, nil) - if err != nil { - http.Error(w, "failed to create proxy request", http.StatusInternalServerError) - return - } - - // Copy headers from the original request to the proxy request - copyHeaders(r.Header, req.Header) - - // Perform the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - http.Error(w, "failed to fetch binary", http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - // Copy headers from the upstream response to the client response - copyHeaders(resp.Header, w.Header()) - - // Set the status code and copy the body - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - http.Error(w, "failed to write response body", http.StatusInternalServerError) - return - } - } -} - -func DownloadChecksum(storeInstance *store.Store, version string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) - return - } - - // Construct the passthrough URL - baseURL := "https://github.com/sonroyaalmerol/pbs-plus/releases/download/" - targetURL := fmt.Sprintf("%s%s/pbs-plus-agent-%s-windows-amd64.exe.md5", baseURL, version, version) - - // Proxy the request - req, err := http.NewRequest(http.MethodGet, targetURL, nil) - if err != nil { - http.Error(w, "failed to create proxy request", http.StatusInternalServerError) - return - } - - // Copy headers from the original request to the proxy request - copyHeaders(r.Header, req.Header) - - // Perform the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - http.Error(w, "failed to fetch checksum", http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - // Copy headers from the upstream response to the client response - copyHeaders(resp.Header, w.Header()) - - // Set the status code and copy the body - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - http.Error(w, "failed to write response body", http.StatusInternalServerError) - return - } - } -} - -// copyHeaders is a helper function to copy headers from one Header map to another -func copyHeaders(src, dst http.Header) { - for name, values := range src { - for _, value := range values { - dst.Add(name, value) - } - } -} +//go:build linux + +package plus + +import ( + "context" + "encoding/base32" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" + "github.com/sonroyaalmerol/pbs-plus/internal/websockets" +) + +func MountHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost && r.Method != http.MethodDelete { + http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) + return + } + + // TODO: add check for security + + targetHostnameEnc := utils.DecodePath(r.PathValue("target")) + agentDriveEnc := utils.DecodePath(r.PathValue("drive")) + + targetHostnameBytes, err := base32.StdEncoding.DecodeString(targetHostnameEnc) + if err != nil { + http.Error(w, "invalid arguments", 
http.StatusBadRequest) + return + } + + agentDriveBytes, err := base32.StdEncoding.DecodeString(agentDriveEnc) + if err != nil { + http.Error(w, "invalid arguments", http.StatusBadRequest) + return + } + + targetHostname := string(targetHostnameBytes) + agentDrive := string(agentDriveBytes) + + if r.Method == http.MethodPost { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + // Create response channel and register handler + respChan := make(chan *websockets.Message, 1) + cleanup := storeInstance.WSHub.RegisterHandler("response-backup_start", func(handlerCtx context.Context, msg *websockets.Message) error { + if msg.Content == "Acknowledged: "+agentDrive { + respChan <- msg + } + return nil + }) + defer cleanup() + + // Send initial message + err := storeInstance.WSHub.SendToClient(targetHostname, websockets.Message{ + Type: "backup_start", + Content: agentDrive, + }) + if err != nil { + http.Error(w, fmt.Sprintf("MountHandler: Failed to send backup request to target -> %v", err), http.StatusInternalServerError) + return + } + + // Wait for either response or timeout + select { + case <-respChan: + case <-ctx.Done(): + http.Error(w, "MountHandler: Timeout waiting for backup acknowledgement from target", http.StatusInternalServerError) + return + } + + // Handle successful response + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]string{"status": "true"}); err != nil { + http.Error(w, fmt.Sprintf("MountHandler: Failed to encode response -> %v", err), http.StatusInternalServerError) + return + } + } + + if r.Method == http.MethodDelete { + _ = storeInstance.WSHub.SendToClient(targetHostname, websockets.Message{ + Type: "backup_close", + Content: agentDrive, + }) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{"status": "true"}) + + return + } + } +} + +func VersionHandler(storeInstance *store.Store, version string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) + return + } + + toReturn := VersionResponse{ + Version: version, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toReturn) + } +} + +func DownloadBinary(storeInstance *store.Store, version string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) + return + } + + // Construct the passthrough URL + baseURL := "https://github.com/sonroyaalmerol/pbs-plus/releases/download/" + targetURL := fmt.Sprintf("%s%s/pbs-plus-agent-%s-windows-amd64.exe", baseURL, version, version) + + // Proxy the request + req, err := http.NewRequest(http.MethodGet, targetURL, nil) + if err != nil { + http.Error(w, "failed to create proxy request", http.StatusInternalServerError) + return + } + + // Copy headers from the original request to the proxy request + copyHeaders(r.Header, req.Header) + + // Perform the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + http.Error(w, "failed to fetch binary", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Copy headers from the upstream response to the client response + copyHeaders(resp.Header, w.Header()) + + // Set the status code and copy the body + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); 
err != nil { + http.Error(w, "failed to write response body", http.StatusInternalServerError) + return + } + } +} + +func DownloadChecksum(storeInstance *store.Store, version string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) + return + } + + // Construct the passthrough URL + baseURL := "https://github.com/sonroyaalmerol/pbs-plus/releases/download/" + targetURL := fmt.Sprintf("%s%s/pbs-plus-agent-%s-windows-amd64.exe.md5", baseURL, version, version) + + // Proxy the request + req, err := http.NewRequest(http.MethodGet, targetURL, nil) + if err != nil { + http.Error(w, "failed to create proxy request", http.StatusInternalServerError) + return + } + + // Copy headers from the original request to the proxy request + copyHeaders(r.Header, req.Header) + + // Perform the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + http.Error(w, "failed to fetch checksum", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Copy headers from the upstream response to the client response + copyHeaders(resp.Header, w.Header()) + + // Set the status code and copy the body + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + http.Error(w, "failed to write response body", http.StatusInternalServerError) + return + } + } +} + +// copyHeaders is a helper function to copy headers from one Header map to another +func copyHeaders(src, dst http.Header) { + for name, values := range src { + for _, value := range values { + dst.Add(name, value) + } + } +} diff --git a/internal/proxy/controllers/plus/req_token.go b/internal/proxy/controllers/plus/req_token.go index 107d459..639cf2e 100644 --- a/internal/proxy/controllers/plus/req_token.go +++ b/internal/proxy/controllers/plus/req_token.go @@ -1,66 +1,66 @@ -//go:build linux - -package plus - -import ( - "encoding/json" - "fmt" - "net/http" - "path/filepath" - "strings" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" - "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -type TokenRequest struct { - PBSAuthCookie string `json:"pbs_auth_cookie"` -} - -func TokenHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) - return - } - - tokenReq := &TokenRequest{} - err := json.NewDecoder(r.Body).Decode(tokenReq) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - } - - decodedAuthCookie := strings.ReplaceAll(tokenReq.PBSAuthCookie, "%3A", ":") - cookieSplit := strings.Split(decodedAuthCookie, ":") - if len(cookieSplit) < 5 { - http.Error(w, "ExtractTokenFromRequest: error invalid cookie, less than 5 split", http.StatusBadRequest) - return - } - - token := proxmox.Token{} - - token.CSRFToken = r.Header.Get("csrfpreventiontoken") - token.Ticket = decodedAuthCookie - token.Username = cookieSplit[1] - - proxmox.Session.LastToken = &token - - if !utils.IsValid(filepath.Join(constants.DbBasePath, "pbs-plus-token.json")) { - apiToken, err := proxmox.Session.CreateAPIToken() - if err != nil { - http.Error(w, fmt.Sprintf("ExtractTokenFromRequest: error creating API token -> %v", err), http.StatusInternalServerError) - return - } - - 
proxmox.Session.APIToken = apiToken - - err = apiToken.SaveToFile() - if err != nil { - http.Error(w, fmt.Sprintf("ExtractTokenFromRequest: error saving API token to file -> %v", err), http.StatusInternalServerError) - return - } - } - } -} +//go:build linux + +package plus + +import ( + "encoding/json" + "fmt" + "net/http" + "path/filepath" + "strings" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" + "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +type TokenRequest struct { + PBSAuthCookie string `json:"pbs_auth_cookie"` +} + +func TokenHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusMethodNotAllowed) + return + } + + tokenReq := &TokenRequest{} + err := json.NewDecoder(r.Body).Decode(tokenReq) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + decodedAuthCookie := strings.ReplaceAll(tokenReq.PBSAuthCookie, "%3A", ":") + cookieSplit := strings.Split(decodedAuthCookie, ":") + if len(cookieSplit) < 5 { + http.Error(w, "ExtractTokenFromRequest: error invalid cookie, less than 5 split", http.StatusBadRequest) + return + } + + token := proxmox.Token{} + + token.CSRFToken = r.Header.Get("csrfpreventiontoken") + token.Ticket = decodedAuthCookie + token.Username = cookieSplit[1] + + proxmox.Session.LastToken = &token + + if !utils.IsValid(filepath.Join(constants.DbBasePath, "pbs-plus-token.json")) { + apiToken, err := proxmox.Session.CreateAPIToken() + if err != nil { + http.Error(w, fmt.Sprintf("ExtractTokenFromRequest: error creating API token -> %v", err), http.StatusInternalServerError) + return + } + + proxmox.Session.APIToken = apiToken + + err = apiToken.SaveToFile() + if err != nil { + http.Error(w, fmt.Sprintf("ExtractTokenFromRequest: error saving API token to file -> %v", err), http.StatusInternalServerError) + return + } + } + } +} diff --git a/internal/proxy/controllers/plus/types.go b/internal/proxy/controllers/plus/types.go index a1c1306..b9d0e35 100644 --- a/internal/proxy/controllers/plus/types.go +++ b/internal/proxy/controllers/plus/types.go @@ -1,7 +1,7 @@ -//go:build linux - -package plus - -type VersionResponse struct { - Version string `json:"version"` -} +//go:build linux + +package plus + +type VersionResponse struct { + Version string `json:"version"` +} diff --git a/internal/proxy/controllers/plus/ws.go b/internal/proxy/controllers/plus/ws.go index b646f25..07e7f9a 100644 --- a/internal/proxy/controllers/plus/ws.go +++ b/internal/proxy/controllers/plus/ws.go @@ -1,20 +1,20 @@ -//go:build linux - -package plus - -import ( - "net/http" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" -) - -func WSHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - storeInstance.WSHub.ServeWS(w, r) - } -} +//go:build linux + +package plus + +import ( + "net/http" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" +) + +func WSHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + storeInstance.WSHub.ServeWS(w, r) + } +} diff 
--git a/internal/proxy/controllers/targets/targets.go b/internal/proxy/controllers/targets/targets.go index 5ccfd84..82309f8 100644 --- a/internal/proxy/controllers/targets/targets.go +++ b/internal/proxy/controllers/targets/targets.go @@ -1,273 +1,273 @@ -//go:build linux - -package targets - -import ( - "encoding/json" - "fmt" - "net/http" - "slices" - "strings" - - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func D2DTargetHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - all, err := storeInstance.Database.GetAllTargets() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - for i, target := range all { - if target.IsAgent { - all[i].ConnectionStatus = storeInstance.WSHub.AgentPing(&target) - all[i].AgentVersion = storeInstance.WSHub.AgentVersion(&target) - } - } - - digest, err := utils.CalculateDigest(all) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - toReturn := TargetsResponse{ - Data: all, - Digest: digest, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(toReturn) - - return - } -} - -type NewAgentHostnameRequest struct { - Hostname string `json:"hostname"` - Drives []utils.DriveInfo `json:"drives"` -} - -func D2DTargetAgentHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - var reqParsed NewAgentHostnameRequest - err := json.NewDecoder(r.Body).Decode(&reqParsed) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - clientIP := r.RemoteAddr - - forwarded := r.Header.Get("X-FORWARDED-FOR") - if forwarded != "" { - clientIP = forwarded - } - - clientIP = strings.Split(clientIP, ":")[0] - - existingTargets, err := storeInstance.Database.GetAllTargetsByIP(clientIP) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - - if len(existingTargets) == 0 { - w.WriteHeader(http.StatusNotFound) - controllers.WriteErrorResponse(w, fmt.Errorf("No targets found.")) - return - } - - targetTemplate := existingTargets[0] - - hostname := r.Header.Get("X-PBS-Agent") - - var driveLetters = make([]string, len(reqParsed.Drives)) - for i, parsedDrive := range reqParsed.Drives { - driveLetters[i] = parsedDrive.Letter - - _ = storeInstance.Database.CreateTarget(types.Target{ - Name: hostname + " - " + parsedDrive.Letter, - Path: "agent://" + clientIP + "/" + parsedDrive.Letter, - Auth: targetTemplate.Auth, - TokenUsed: targetTemplate.TokenUsed, - DriveType: parsedDrive.Type, - DriveName: parsedDrive.VolumeName, - DriveFS: parsedDrive.FileSystem, - DriveFreeBytes: int(parsedDrive.FreeBytes), - DriveUsedBytes: int(parsedDrive.UsedBytes), - DriveTotalBytes: int(parsedDrive.TotalBytes), - DriveFree: parsedDrive.Free, - DriveUsed: parsedDrive.Used, - DriveTotal: parsedDrive.Total, - }) - } - - for _, target := range existingTargets { - targetDrive := strings.Split(target.Path, "/")[3] - if !slices.Contains(driveLetters, targetDrive) { - _ = 
storeInstance.Database.DeleteTarget(target.Name) - } - } - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(map[string]bool{ - "success": true, - }) - - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - controllers.WriteErrorResponse(w, err) - return - } - } -} - -func ExtJsTargetHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := TargetConfigResponse{} - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - err := r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - if !utils.IsValid(r.FormValue("path")) { - controllers.WriteErrorResponse(w, fmt.Errorf("invalid path '%s'", r.FormValue("path"))) - return - } - - newTarget := types.Target{ - Name: r.FormValue("name"), - Path: r.FormValue("path"), - } - - err = storeInstance.Database.CreateTarget(newTarget) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - } -} - -func ExtJsTargetSingleHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := TargetConfigResponse{} - if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPut { - err := r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - if !utils.IsValid(r.FormValue("path")) { - controllers.WriteErrorResponse(w, fmt.Errorf("invalid path '%s'", r.FormValue("path"))) - return - } - - target, err := storeInstance.Database.GetTarget(utils.DecodePath(r.PathValue("target"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - if r.FormValue("name") != "" { - target.Name = r.FormValue("name") - } - if r.FormValue("path") != "" { - target.Path = r.FormValue("path") - } - - if delArr, ok := r.Form["delete"]; ok { - for _, attr := range delArr { - switch attr { - case "name": - target.Name = "" - case "path": - target.Path = "" - } - } - } - - err = storeInstance.Database.UpdateTarget(*target) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodGet { - target, err := storeInstance.Database.GetTarget(utils.DecodePath(r.PathValue("target"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - if target.IsAgent { - target.ConnectionStatus = storeInstance.WSHub.AgentPing(target) - target.AgentVersion = storeInstance.WSHub.AgentVersion(target) - } - - response.Status = http.StatusOK - response.Success = true - response.Data = target - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodDelete { - err := storeInstance.Database.DeleteTarget(utils.DecodePath(r.PathValue("target"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - return - } - } -} +//go:build linux + +package targets + +import ( + "encoding/json" + "fmt" + "net/http" + "slices" + "strings" 
+ + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func D2DTargetHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + all, err := storeInstance.Database.GetAllTargets() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + for i, target := range all { + if target.IsAgent { + all[i].ConnectionStatus = storeInstance.WSHub.AgentPing(&target) + all[i].AgentVersion = storeInstance.WSHub.AgentVersion(&target) + } + } + + digest, err := utils.CalculateDigest(all) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + toReturn := TargetsResponse{ + Data: all, + Digest: digest, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toReturn) + + return + } +} + +type NewAgentHostnameRequest struct { + Hostname string `json:"hostname"` + Drives []utils.DriveInfo `json:"drives"` +} + +func D2DTargetAgentHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + var reqParsed NewAgentHostnameRequest + err := json.NewDecoder(r.Body).Decode(&reqParsed) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + clientIP := r.RemoteAddr + + forwarded := r.Header.Get("X-FORWARDED-FOR") + if forwarded != "" { + clientIP = forwarded + } + + clientIP = strings.Split(clientIP, ":")[0] + + existingTargets, err := storeInstance.Database.GetAllTargetsByIP(clientIP) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + + if len(existingTargets) == 0 { + w.WriteHeader(http.StatusNotFound) + controllers.WriteErrorResponse(w, fmt.Errorf("No targets found.")) + return + } + + targetTemplate := existingTargets[0] + + hostname := r.Header.Get("X-PBS-Agent") + + var driveLetters = make([]string, len(reqParsed.Drives)) + for i, parsedDrive := range reqParsed.Drives { + driveLetters[i] = parsedDrive.Letter + + _ = storeInstance.Database.CreateTarget(types.Target{ + Name: hostname + " - " + parsedDrive.Letter, + Path: "agent://" + clientIP + "/" + parsedDrive.Letter, + Auth: targetTemplate.Auth, + TokenUsed: targetTemplate.TokenUsed, + DriveType: parsedDrive.Type, + DriveName: parsedDrive.VolumeName, + DriveFS: parsedDrive.FileSystem, + DriveFreeBytes: int(parsedDrive.FreeBytes), + DriveUsedBytes: int(parsedDrive.UsedBytes), + DriveTotalBytes: int(parsedDrive.TotalBytes), + DriveFree: parsedDrive.Free, + DriveUsed: parsedDrive.Used, + DriveTotal: parsedDrive.Total, + }) + } + + for _, target := range existingTargets { + targetDrive := strings.Split(target.Path, "/")[3] + if !slices.Contains(driveLetters, targetDrive) { + _ = storeInstance.Database.DeleteTarget(target.Name) + } + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(map[string]bool{ + "success": true, + }) + + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + controllers.WriteErrorResponse(w, err) + return + } + } +} + +func ExtJsTargetHandler(storeInstance *store.Store) 
http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := TargetConfigResponse{} + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + err := r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + if !utils.IsValid(r.FormValue("path")) { + controllers.WriteErrorResponse(w, fmt.Errorf("invalid path '%s'", r.FormValue("path"))) + return + } + + newTarget := types.Target{ + Name: r.FormValue("name"), + Path: r.FormValue("path"), + } + + err = storeInstance.Database.CreateTarget(newTarget) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + } +} + +func ExtJsTargetSingleHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := TargetConfigResponse{} + if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + if r.Method == http.MethodPut { + err := r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + if !utils.IsValid(r.FormValue("path")) { + controllers.WriteErrorResponse(w, fmt.Errorf("invalid path '%s'", r.FormValue("path"))) + return + } + + target, err := storeInstance.Database.GetTarget(utils.DecodePath(r.PathValue("target"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + if r.FormValue("name") != "" { + target.Name = r.FormValue("name") + } + if r.FormValue("path") != "" { + target.Path = r.FormValue("path") + } + + if delArr, ok := r.Form["delete"]; ok { + for _, attr := range delArr { + switch attr { + case "name": + target.Name = "" + case "path": + target.Path = "" + } + } + } + + err = storeInstance.Database.UpdateTarget(*target) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodGet { + target, err := storeInstance.Database.GetTarget(utils.DecodePath(r.PathValue("target"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + if target.IsAgent { + target.ConnectionStatus = storeInstance.WSHub.AgentPing(target) + target.AgentVersion = storeInstance.WSHub.AgentVersion(target) + } + + response.Status = http.StatusOK + response.Success = true + response.Data = target + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodDelete { + err := storeInstance.Database.DeleteTarget(utils.DecodePath(r.PathValue("target"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + return + } + } +} diff --git a/internal/proxy/controllers/targets/types.go b/internal/proxy/controllers/targets/types.go index 4f9a4cd..c2b037c 100644 --- a/internal/proxy/controllers/targets/types.go +++ b/internal/proxy/controllers/targets/types.go @@ -1,20 +1,20 @@ -//go:build linux - -package targets - -import ( - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" -) - -type TargetsResponse struct { - Data []types.Target `json:"data"` - Digest string `json:"digest"` -} - -type 
TargetConfigResponse struct { - Errors map[string]string `json:"errors"` - Message string `json:"message"` - Data *types.Target `json:"data"` - Status int `json:"status"` - Success bool `json:"success"` -} +//go:build linux + +package targets + +import ( + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" +) + +type TargetsResponse struct { + Data []types.Target `json:"data"` + Digest string `json:"digest"` +} + +type TargetConfigResponse struct { + Errors map[string]string `json:"errors"` + Message string `json:"message"` + Data *types.Target `json:"data"` + Status int `json:"status"` + Success bool `json:"success"` +} diff --git a/internal/proxy/controllers/tokens/tokens.go b/internal/proxy/controllers/tokens/tokens.go index f189771..49aad1a 100644 --- a/internal/proxy/controllers/tokens/tokens.go +++ b/internal/proxy/controllers/tokens/tokens.go @@ -1,122 +1,122 @@ -//go:build linux - -package tokens - -import ( - "encoding/json" - "net/http" - - "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" - "github.com/sonroyaalmerol/pbs-plus/internal/store" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func D2DTokenHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - all, err := storeInstance.Database.GetAllTokens() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - digest, err := utils.CalculateDigest(all) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - toReturn := TokensResponse{ - Data: all, - Digest: digest, - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(toReturn) - - return - } -} - -func ExtJsTokenHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := TokenConfigResponse{} - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - err := r.ParseForm() - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - newToken := types.AgentToken{ - Comment: r.FormValue("comment"), - } - - err = storeInstance.Database.CreateToken(newToken.Comment) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - } -} - -func ExtJsTokenSingleHandler(storeInstance *store.Store) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - response := TokenConfigResponse{} - if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { - http.Error(w, "Invalid HTTP method", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodGet { - token, err := storeInstance.Database.GetToken(utils.DecodePath(r.PathValue("token"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - response.Data = token - json.NewEncoder(w).Encode(response) - - return - } - - if r.Method == http.MethodDelete { - token, err := storeInstance.Database.GetToken(utils.DecodePath(r.PathValue("token"))) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - err = 
storeInstance.Database.RevokeToken(token) - if err != nil { - controllers.WriteErrorResponse(w, err) - return - } - - response.Status = http.StatusOK - response.Success = true - json.NewEncoder(w).Encode(response) - return - } - } -} +//go:build linux + +package tokens + +import ( + "encoding/json" + "net/http" + + "github.com/sonroyaalmerol/pbs-plus/internal/proxy/controllers" + "github.com/sonroyaalmerol/pbs-plus/internal/store" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func D2DTokenHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + all, err := storeInstance.Database.GetAllTokens() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + digest, err := utils.CalculateDigest(all) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + toReturn := TokensResponse{ + Data: all, + Digest: digest, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(toReturn) + + return + } +} + +func ExtJsTokenHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := TokenConfigResponse{} + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + err := r.ParseForm() + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + newToken := types.AgentToken{ + Comment: r.FormValue("comment"), + } + + err = storeInstance.Database.CreateToken(newToken.Comment) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + } +} + +func ExtJsTokenSingleHandler(storeInstance *store.Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + response := TokenConfigResponse{} + if r.Method != http.MethodPut && r.Method != http.MethodGet && r.Method != http.MethodDelete { + http.Error(w, "Invalid HTTP method", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + if r.Method == http.MethodGet { + token, err := storeInstance.Database.GetToken(utils.DecodePath(r.PathValue("token"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + response.Data = token + json.NewEncoder(w).Encode(response) + + return + } + + if r.Method == http.MethodDelete { + token, err := storeInstance.Database.GetToken(utils.DecodePath(r.PathValue("token"))) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + err = storeInstance.Database.RevokeToken(token) + if err != nil { + controllers.WriteErrorResponse(w, err) + return + } + + response.Status = http.StatusOK + response.Success = true + json.NewEncoder(w).Encode(response) + return + } + } +} diff --git a/internal/proxy/controllers/tokens/types.go b/internal/proxy/controllers/tokens/types.go index a6c38f3..265bdc1 100644 --- a/internal/proxy/controllers/tokens/types.go +++ b/internal/proxy/controllers/tokens/types.go @@ -1,20 +1,20 @@ -//go:build linux - -package tokens - -import ( - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" -) - -type TokensResponse struct { - Data []types.AgentToken 
`json:"data"` - Digest string `json:"digest"` -} - -type TokenConfigResponse struct { - Errors map[string]string `json:"errors"` - Message string `json:"message"` - Data *types.AgentToken `json:"data"` - Status int `json:"status"` - Success bool `json:"success"` -} +//go:build linux + +package tokens + +import ( + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" +) + +type TokensResponse struct { + Data []types.AgentToken `json:"data"` + Digest string `json:"digest"` +} + +type TokenConfigResponse struct { + Errors map[string]string `json:"errors"` + Message string `json:"message"` + Data *types.AgentToken `json:"data"` + Status int `json:"status"` + Success bool `json:"success"` +} diff --git a/internal/proxy/js_compiler.go b/internal/proxy/js_compiler.go index cb732f8..8726ad7 100644 --- a/internal/proxy/js_compiler.go +++ b/internal/proxy/js_compiler.go @@ -1,208 +1,208 @@ -//go:build linux - -package proxy - -import ( - "embed" - "fmt" - "io/fs" - "log" - "os" - "path/filepath" - "strings" - "syscall" - - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -//go:embed all:views -var customJsFS embed.FS - -func compileCustomJS() []byte { - result := []byte(` -const pbsFullUrl = window.location.href; -const pbsUrl = new URL(pbsFullUrl); -const pbsPlusBaseUrl = ` + "`${pbsUrl.protocol}//${pbsUrl.hostname}:8008`" + `; - -function getCookie(cName) { - const name = cName + "="; - const cDecoded = decodeURIComponent(document.cookie); - const cArr = cDecoded.split('; '); - let res; - cArr.forEach(val => { - if (val.indexOf(name) === 0) res = val.substring(name.length); - }) - return res -} - -var pbsPlusTokenHeaders = { - "Content-Type": "application/json", -}; - -if (Proxmox.CSRFPreventionToken) { - pbsPlusTokenHeaders["Csrfpreventiontoken"] = Proxmox.CSRFPreventionToken; -} - -fetch(pbsPlusBaseUrl + "/plus/token", { - method: "POST", - body: JSON.stringify({ - "pbs_auth_cookie": getCookie("PBSAuthCookie"), - }), - headers: pbsPlusTokenHeaders, -}) - -function encodePathValue(path) { - const encoded = btoa(path) - .replace(/\+/g, '-') - .replace(/\//g, '_') - .replace(/=+$/, ''); - return encoded; -} -`) - - err := fs.WalkDir(customJsFS, ".", func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if d.IsDir() { - return nil - } - content, err := customJsFS.ReadFile(path) - if err != nil { - return err - } - result = append(result, content...) - result = append(result, []byte("\n")...) 
- return nil - }) - if err != nil { - log.Println(err) - } - return result -} - -// MountCompiledJS creates a backup of the target file and mounts the compiled JS over it -func MountCompiledJS(targetPath string) error { - // Check if something is already mounted at the target path - if utils.IsMounted(targetPath) { - if err := syscall.Unmount(targetPath, 0); err != nil { - return fmt.Errorf("failed to unmount existing file: %w", err) - } - } - - // Create backup directory if it doesn't exist - backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") - if err := os.MkdirAll(backupDir, 0755); err != nil { - return fmt.Errorf("failed to create backup directory: %w", err) - } - - // Create backup filename with timestamp - backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(targetPath))) - - // Read existing file - original, err := os.ReadFile(targetPath) - if err != nil { - return fmt.Errorf("failed to read original file: %w", err) - } - - // Create backup - if err := os.WriteFile(backupPath, original, 0644); err != nil { - return fmt.Errorf("failed to create backup: %w", err) - } - - // Create new file with compiled JS - compiledJS := compileCustomJS() - - newContent := make([]byte, len(original)+1+len(compiledJS)) - copy(newContent, original) - newContent[len(original)] = '\n' // Add newline - copy(newContent[len(original)+1:], compiledJS) - - tempFile := filepath.Join(backupDir, filepath.Base(targetPath)) - if err := os.WriteFile(tempFile, newContent, 0644); err != nil { - return fmt.Errorf("failed to write new content: %w", err) - } - - // Perform bind mount - if err := syscall.Mount(tempFile, targetPath, "", syscall.MS_BIND, ""); err != nil { - return fmt.Errorf("failed to mount file: %w", err) - } - - return nil -} - -func MountModdedProxmoxLib(targetPath string) error { - // Check if something is already mounted at the target path - if utils.IsMounted(targetPath) { - if err := syscall.Unmount(targetPath, 0); err != nil { - return fmt.Errorf("failed to unmount existing file: %w", err) - } - } - - // Create backup directory if it doesn't exist - backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") - if err := os.MkdirAll(backupDir, 0755); err != nil { - return fmt.Errorf("failed to create backup directory: %w", err) - } - - // Create backup filename with timestamp - backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(targetPath))) - - // Read existing file - original, err := os.ReadFile(targetPath) - if err != nil { - return fmt.Errorf("failed to read original file: %w", err) - } - - // Create backup - if err := os.WriteFile(backupPath, original, 0644); err != nil { - return fmt.Errorf("failed to create backup: %w", err) - } - - oldString := `if (!newopts.url.match(/^\/api2/))` - newString := `if (!newopts.url.match(/^\/api2/) && !newopts.url.match(/^[a-z][a-z\d+\-.]*:/i))` - - // Perform the replacement - newContent := strings.Replace(string(original), oldString, newString, 1) - - tempFile := filepath.Join(backupDir, filepath.Base(targetPath)) - if err := os.WriteFile(tempFile, []byte(newContent), 0644); err != nil { - return fmt.Errorf("failed to write new content: %w", err) - } - - // Perform bind mount - if err := syscall.Mount(tempFile, targetPath, "", syscall.MS_BIND, ""); err != nil { - return fmt.Errorf("failed to mount file: %w", err) - } - - return nil -} - -// UnmountCompiledJS unmounts the file and restores the original -func UnmountModdedFile(targetPath string) error { - // Unmount the file - if err := 
syscall.Unmount(targetPath, 0); err != nil { - return fmt.Errorf("failed to unmount file: %w", err) - } - - // Path to backup file - backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") - backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(targetPath))) - - // Restore from backup if it exists - if _, err := os.Stat(backupPath); err == nil { - backup, err := os.ReadFile(backupPath) - if err != nil { - return fmt.Errorf("failed to read backup: %w", err) - } - - if err := os.WriteFile(targetPath, backup, 0644); err != nil { - return fmt.Errorf("failed to restore backup: %w", err) - } - - // Clean up backup files - os.RemoveAll(backupDir) - } - - return nil -} +//go:build linux + +package proxy + +import ( + "embed" + "fmt" + "io/fs" + "log" + "os" + "path/filepath" + "strings" + "syscall" + + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +//go:embed all:views +var customJsFS embed.FS + +func compileCustomJS() []byte { + result := []byte(` +const pbsFullUrl = window.location.href; +const pbsUrl = new URL(pbsFullUrl); +const pbsPlusBaseUrl = ` + "`${pbsUrl.protocol}//${pbsUrl.hostname}:8008`" + `; + +function getCookie(cName) { + const name = cName + "="; + const cDecoded = decodeURIComponent(document.cookie); + const cArr = cDecoded.split('; '); + let res; + cArr.forEach(val => { + if (val.indexOf(name) === 0) res = val.substring(name.length); + }) + return res +} + +var pbsPlusTokenHeaders = { + "Content-Type": "application/json", +}; + +if (Proxmox.CSRFPreventionToken) { + pbsPlusTokenHeaders["Csrfpreventiontoken"] = Proxmox.CSRFPreventionToken; +} + +fetch(pbsPlusBaseUrl + "/plus/token", { + method: "POST", + body: JSON.stringify({ + "pbs_auth_cookie": getCookie("PBSAuthCookie"), + }), + headers: pbsPlusTokenHeaders, +}) + +function encodePathValue(path) { + const encoded = btoa(path) + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=+$/, ''); + return encoded; +} +`) + + err := fs.WalkDir(customJsFS, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + content, err := customJsFS.ReadFile(path) + if err != nil { + return err + } + result = append(result, content...) + result = append(result, []byte("\n")...) 
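+ // Note: fs.WalkDir visits the embedded views tree in lexical order, so the + // custom JS files are concatenated into the bundle deterministically.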
+ return nil + }) + if err != nil { + log.Println(err) + } + return result +} + +// MountCompiledJS creates a backup of the target file and mounts the compiled JS over it +func MountCompiledJS(targetPath string) error { + // Check if something is already mounted at the target path + if utils.IsMounted(targetPath) { + if err := syscall.Unmount(targetPath, 0); err != nil { + return fmt.Errorf("failed to unmount existing file: %w", err) + } + } + + // Create backup directory if it doesn't exist + backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Create backup filename with timestamp + backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(targetPath))) + + // Read existing file + original, err := os.ReadFile(targetPath) + if err != nil { + return fmt.Errorf("failed to read original file: %w", err) + } + + // Create backup + if err := os.WriteFile(backupPath, original, 0644); err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + + // Create new file with compiled JS + compiledJS := compileCustomJS() + + newContent := make([]byte, len(original)+1+len(compiledJS)) + copy(newContent, original) + newContent[len(original)] = '\n' // Add newline + copy(newContent[len(original)+1:], compiledJS) + + tempFile := filepath.Join(backupDir, filepath.Base(targetPath)) + if err := os.WriteFile(tempFile, newContent, 0644); err != nil { + return fmt.Errorf("failed to write new content: %w", err) + } + + // Perform bind mount + if err := syscall.Mount(tempFile, targetPath, "", syscall.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to mount file: %w", err) + } + + return nil +} + +func MountModdedProxmoxLib(targetPath string) error { + // Check if something is already mounted at the target path + if utils.IsMounted(targetPath) { + if err := syscall.Unmount(targetPath, 0); err != nil { + return fmt.Errorf("failed to unmount existing file: %w", err) + } + } + + // Create backup directory if it doesn't exist + backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") + if err := os.MkdirAll(backupDir, 0755); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + // Create backup filename with timestamp + backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(targetPath))) + + // Read existing file + original, err := os.ReadFile(targetPath) + if err != nil { + return fmt.Errorf("failed to read original file: %w", err) + } + + // Create backup + if err := os.WriteFile(backupPath, original, 0644); err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + + oldString := `if (!newopts.url.match(/^\/api2/))` + newString := `if (!newopts.url.match(/^\/api2/) && !newopts.url.match(/^[a-z][a-z\d+\-.]*:/i))` + + // Perform the replacement + newContent := strings.Replace(string(original), oldString, newString, 1) + + tempFile := filepath.Join(backupDir, filepath.Base(targetPath)) + if err := os.WriteFile(tempFile, []byte(newContent), 0644); err != nil { + return fmt.Errorf("failed to write new content: %w", err) + } + + // Perform bind mount + if err := syscall.Mount(tempFile, targetPath, "", syscall.MS_BIND, ""); err != nil { + return fmt.Errorf("failed to mount file: %w", err) + } + + return nil +} + +// UnmountCompiledJS unmounts the file and restores the original +func UnmountModdedFile(targetPath string) error { + // Unmount the file + if err := 
syscall.Unmount(targetPath, 0); err != nil { + return fmt.Errorf("failed to unmount file: %w", err) + } + + // Path to backup file + backupDir := filepath.Join(os.TempDir(), "pbs-plus-backups") + backupPath := filepath.Join(backupDir, fmt.Sprintf("%s.backup", filepath.Base(targetPath))) + + // Restore from backup if it exists + if _, err := os.Stat(backupPath); err == nil { + backup, err := os.ReadFile(backupPath) + if err != nil { + return fmt.Errorf("failed to read backup: %w", err) + } + + if err := os.WriteFile(targetPath, backup, 0644); err != nil { + return fmt.Errorf("failed to restore backup: %w", err) + } + + // Clean up backup files + os.RemoveAll(backupDir) + } + + return nil +} diff --git a/internal/proxy/middlewares/auth.go b/internal/proxy/middlewares/auth.go index 84a53bd..ef23f8a 100644 --- a/internal/proxy/middlewares/auth.go +++ b/internal/proxy/middlewares/auth.go @@ -1,155 +1,155 @@ -//go:build linux - -package middlewares - -import ( - "crypto/x509" - "encoding/base64" - "encoding/pem" - "fmt" - "net/http" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" -) - -func AgentOnly(store *store.Store, next http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if err := checkAgentAuth(store, r); err != nil { - http.Error(w, "authentication failed - no authentication credentials provided", http.StatusUnauthorized) - return - } - - next.ServeHTTP(w, r) - } -} - -func ServerOnly(store *store.Store, next http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if err := checkProxyAuth(r); err != nil { - http.Error(w, "authentication failed - no authentication credentials provided", http.StatusUnauthorized) - return - } - - next.ServeHTTP(w, r) - } -} - -func AgentOrServer(store *store.Store, next http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - authenticated := false - - if err := checkAgentAuth(store, r); err == nil { - authenticated = true - } - - if err := checkProxyAuth(r); err == nil { - authenticated = true - } - - if !authenticated { - http.Error(w, "authentication failed - no authentication credentials provided", http.StatusUnauthorized) - return - } - - next.ServeHTTP(w, r) - } -} - -func checkAgentAuth(store *store.Store, r *http.Request) error { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - return fmt.Errorf("CheckAgentAuth: client certificate required") - } - - agentHostname := r.Header.Get("X-PBS-Agent") - if agentHostname == "" { - return fmt.Errorf("CheckAgentAuth: missing X-PBS-Agent header") - } - - trustedCert, err := loadTrustedCert(store, agentHostname+" - C") - if err != nil { - return fmt.Errorf("CheckAgentAuth: certificate not trusted") - } - - clientCert := r.TLS.PeerCertificates[0] - - if !clientCert.Equal(trustedCert) { - return fmt.Errorf("certificate does not match pinned certificate") - } - - return nil -} - -func checkProxyAuth(r *http.Request) error { - agentHostname := r.Header.Get("X-PBS-Agent") - if agentHostname != "" { - return fmt.Errorf("CheckProxyAuth: agent unauthorized") - } - // checkEndpoint := "/api2/json/version" - // req, err := http.NewRequest( - // http.MethodGet, - // fmt.Sprintf( - // "%s%s", - // ProxyTargetURL, - // checkEndpoint, - // ), - // nil, - // ) - - // if err != nil { - // return fmt.Errorf("CheckProxyAuth: error creating http request -> %w", err) - // } - - // for _, cookie := range r.Cookies() { - // req.AddCookie(cookie) - // } - - // if authHead := r.Header.Get("Authorization"); 
authHead != "" { - // req.Header.Set("Authorization", authHead) - // } - - // if storeInstance.HTTPClient == nil { - // storeInstance.HTTPClient = &http.Client{ - // Timeout: time.Second * 30, - // Transport: utils.BaseTransport, - // } - // } - - // resp, err := storeInstance.HTTPClient.Do(req) - // if err != nil { - // return fmt.Errorf("CheckProxyAuth: invalid auth -> %w", err) - // } - // defer func() { - // _, _ = io.Copy(io.Discard, resp.Body) - // resp.Body.Close() - // }() - - // if resp.StatusCode > 299 || resp.StatusCode < 200 { - // return fmt.Errorf("CheckProxyAuth: invalid auth -> %w", err) - // } - - return nil -} - -func loadTrustedCert(store *store.Store, targetName string) (*x509.Certificate, error) { - target, err := store.Database.GetTarget(targetName) - if err != nil { - return nil, fmt.Errorf("failed to get target: %w", err) - } - - decodedCert, err := base64.StdEncoding.DecodeString(target.Auth) - if err != nil { - return nil, fmt.Errorf("failed to get target cert: %w", err) - } - - block, _ := pem.Decode(decodedCert) - if block == nil { - return nil, fmt.Errorf("failed to decode PEM block") - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - return nil, fmt.Errorf("failed to parse certificate: %v", err) - } - - return cert, nil -} +//go:build linux + +package middlewares + +import ( + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "net/http" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" +) + +func AgentOnly(store *store.Store, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := checkAgentAuth(store, r); err != nil { + http.Error(w, "authentication failed - no authentication credentials provided", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + } +} + +func ServerOnly(store *store.Store, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := checkProxyAuth(r); err != nil { + http.Error(w, "authentication failed - no authentication credentials provided", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + } +} + +func AgentOrServer(store *store.Store, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authenticated := false + + if err := checkAgentAuth(store, r); err == nil { + authenticated = true + } + + if err := checkProxyAuth(r); err == nil { + authenticated = true + } + + if !authenticated { + http.Error(w, "authentication failed - no authentication credentials provided", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + } +} + +func checkAgentAuth(store *store.Store, r *http.Request) error { + if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { + return fmt.Errorf("CheckAgentAuth: client certificate required") + } + + agentHostname := r.Header.Get("X-PBS-Agent") + if agentHostname == "" { + return fmt.Errorf("CheckAgentAuth: missing X-PBS-Agent header") + } + + trustedCert, err := loadTrustedCert(store, agentHostname+" - C") + if err != nil { + return fmt.Errorf("CheckAgentAuth: certificate not trusted") + } + + clientCert := r.TLS.PeerCertificates[0] + + if !clientCert.Equal(trustedCert) { + return fmt.Errorf("certificate does not match pinned certificate") + } + + return nil +} + +func checkProxyAuth(r *http.Request) error { + agentHostname := r.Header.Get("X-PBS-Agent") + if agentHostname != "" { + return fmt.Errorf("CheckProxyAuth: agent unauthorized") + } + // checkEndpoint := "/api2/json/version" + // req, err := 
http.NewRequest( + // http.MethodGet, + // fmt.Sprintf( + // "%s%s", + // ProxyTargetURL, + // checkEndpoint, + // ), + // nil, + // ) + + // if err != nil { + // return fmt.Errorf("CheckProxyAuth: error creating http request -> %w", err) + // } + + // for _, cookie := range r.Cookies() { + // req.AddCookie(cookie) + // } + + // if authHead := r.Header.Get("Authorization"); authHead != "" { + // req.Header.Set("Authorization", authHead) + // } + + // if storeInstance.HTTPClient == nil { + // storeInstance.HTTPClient = &http.Client{ + // Timeout: time.Second * 30, + // Transport: utils.BaseTransport, + // } + // } + + // resp, err := storeInstance.HTTPClient.Do(req) + // if err != nil { + // return fmt.Errorf("CheckProxyAuth: invalid auth -> %w", err) + // } + // defer func() { + // _, _ = io.Copy(io.Discard, resp.Body) + // resp.Body.Close() + // }() + + // if resp.StatusCode > 299 || resp.StatusCode < 200 { + // return fmt.Errorf("CheckProxyAuth: invalid auth -> %w", err) + // } + + return nil +} + +func loadTrustedCert(store *store.Store, targetName string) (*x509.Certificate, error) { + target, err := store.Database.GetTarget(targetName) + if err != nil { + return nil, fmt.Errorf("failed to get target: %w", err) + } + + decodedCert, err := base64.StdEncoding.DecodeString(target.Auth) + if err != nil { + return nil, fmt.Errorf("failed to get target cert: %w", err) + } + + block, _ := pem.Decode(decodedCert) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %v", err) + } + + return cert, nil +} diff --git a/internal/proxy/middlewares/token.go b/internal/proxy/middlewares/token.go index db84d61..cc1ced3 100644 --- a/internal/proxy/middlewares/token.go +++ b/internal/proxy/middlewares/token.go @@ -1,43 +1,43 @@ -//go:build linux - -package middlewares - -import ( - "log" - "net/http" - - "github.com/sonroyaalmerol/pbs-plus/internal/store" -) - -func CORS(store *store.Store, next http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - allowedOrigin := r.Header.Get("Origin") - if allowedOrigin != "" { - allowedHeaders := r.Header.Get("Access-Control-Request-Headers") - if allowedHeaders == "" { - allowedHeaders = "Content-Type, *" - } - - allowedMethods := r.Header.Get("Access-Control-Request-Method") - if allowedMethods == "" { - allowedMethods = "POST, GET, OPTIONS, PUT, DELETE" - } - - w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) - w.Header().Set("Access-Control-Allow-Methods", allowedMethods) - w.Header().Set("Access-Control-Allow-Headers", allowedHeaders) - w.Header().Set("Access-Control-Allow-Credentials", "true") - } - - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte{}) - if err != nil { - log.Printf("cannot send 200 answer → %v", err) - } - return - } - - next.ServeHTTP(w, r) - } -} +//go:build linux + +package middlewares + +import ( + "log" + "net/http" + + "github.com/sonroyaalmerol/pbs-plus/internal/store" +) + +func CORS(store *store.Store, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + allowedOrigin := r.Header.Get("Origin") + if allowedOrigin != "" { + allowedHeaders := r.Header.Get("Access-Control-Request-Headers") + if allowedHeaders == "" { + allowedHeaders = "Content-Type, *" + } + + allowedMethods := r.Header.Get("Access-Control-Request-Method") + if allowedMethods == "" { + 
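+ // Note: a missing Access-Control-Request-Method header means this is not a + // CORS preflight, so a permissive default method list is used; since the + // request Origin is reflected with Allow-Credentials below, the handlers + // behind this middleware must enforce their own authentication.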
allowedMethods = "POST, GET, OPTIONS, PUT, DELETE" + } + + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + w.Header().Set("Access-Control-Allow-Methods", allowedMethods) + w.Header().Set("Access-Control-Allow-Headers", allowedHeaders) + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte{}) + if err != nil { + log.Printf("cannot send 200 answer → %v", err) + } + return + } + + next.ServeHTTP(w, r) + } +} diff --git a/internal/proxy/views/d2d_backup/disk_backup.js b/internal/proxy/views/d2d_backup/disk_backup.js index 35a5b77..6ac17eb 100644 --- a/internal/proxy/views/d2d_backup/disk_backup.js +++ b/internal/proxy/views/d2d_backup/disk_backup.js @@ -1,41 +1,41 @@ -Ext.define("PBS.D2DManagement", { - extend: "Ext.tab.Panel", - alias: "widget.pbsD2DManagement", - - title: "Disk Backup", - - tools: [], - - border: true, - defaults: { - border: false, - xtype: "panel", - }, - - items: [ - { - xtype: "pbsDiskBackupJobView", - title: gettext("Backup Jobs"), - itemId: "d2d-backup-jobs", - iconCls: "fa fa-floppy-o", - }, - { - xtype: "pbsDiskTokenPanel", - title: "Agent Bootstrap", - itemId: "tokens", - iconCls: "fa fa-handshake-o", - }, - { - xtype: "pbsDiskTargetPanel", - title: "Targets", - itemId: "targets", - iconCls: "fa fa-desktop", - }, - { - xtype: "pbsDiskExclusionPanel", - title: "Global Exclusions", - itemId: "exclusions", - iconCls: "fa fa-ban", - }, - ], -}); +Ext.define("PBS.D2DManagement", { + extend: "Ext.tab.Panel", + alias: "widget.pbsD2DManagement", + + title: "Disk Backup", + + tools: [], + + border: true, + defaults: { + border: false, + xtype: "panel", + }, + + items: [ + { + xtype: "pbsDiskBackupJobView", + title: gettext("Backup Jobs"), + itemId: "d2d-backup-jobs", + iconCls: "fa fa-floppy-o", + }, + { + xtype: "pbsDiskTokenPanel", + title: "Agent Bootstrap", + itemId: "tokens", + iconCls: "fa fa-handshake-o", + }, + { + xtype: "pbsDiskTargetPanel", + title: "Targets", + itemId: "targets", + iconCls: "fa fa-desktop", + }, + { + xtype: "pbsDiskExclusionPanel", + title: "Global Exclusions", + itemId: "exclusions", + iconCls: "fa fa-ban", + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/models.js b/internal/proxy/views/d2d_backup/models.js index 522efbd..0f1549a 100644 --- a/internal/proxy/views/d2d_backup/models.js +++ b/internal/proxy/views/d2d_backup/models.js @@ -1,54 +1,54 @@ -Ext.define("pbs-disk-backup-job-status", { - extend: "Ext.data.Model", - fields: [ - "id", - "store", - "target", - "subpath", - "schedule", - "comment", - "duration", - "next-run", - "last-run-upid", - "last-run-state", - "last-run-endtime", - "rawexclusions", - ], - idProperty: "id", - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/backup", - }, -}); - -Ext.define("pbs-model-targets", { - extend: "Ext.data.Model", - fields: [ - "name", - "path", - "drive_type", - "agent_version", - "connection_status", - "drive_name", - "drive_fs", - "drive_total_bytes", - "drive_used_bytes", - "drive_free_bytes", - "drive_total", - "drive_used", - "drive_free", - ], - idProperty: "name", -}); - -Ext.define("pbs-model-tokens", { - extend: "Ext.data.Model", - fields: ["token", "comment", "created_at", "revoked"], - idProperty: "token", -}); - -Ext.define("pbs-model-exclusions", { - extend: "Ext.data.Model", - fields: ["path", "comment"], - idProperty: "path", -}); +Ext.define("pbs-disk-backup-job-status", { + extend: "Ext.data.Model", + fields: [ + "id", + 
"store", + "target", + "subpath", + "schedule", + "comment", + "duration", + "next-run", + "last-run-upid", + "last-run-state", + "last-run-endtime", + "rawexclusions", + ], + idProperty: "id", + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/backup", + }, +}); + +Ext.define("pbs-model-targets", { + extend: "Ext.data.Model", + fields: [ + "name", + "path", + "drive_type", + "agent_version", + "connection_status", + "drive_name", + "drive_fs", + "drive_total_bytes", + "drive_used_bytes", + "drive_free_bytes", + "drive_total", + "drive_used", + "drive_free", + ], + idProperty: "name", +}); + +Ext.define("pbs-model-tokens", { + extend: "Ext.data.Model", + fields: ["token", "comment", "created_at", "revoked"], + idProperty: "token", +}); + +Ext.define("pbs-model-exclusions", { + extend: "Ext.data.Model", + fields: ["path", "comment"], + idProperty: "path", +}); diff --git a/internal/proxy/views/d2d_backup/panels/exclusions.js b/internal/proxy/views/d2d_backup/panels/exclusions.js index 0f91b4d..3845d4e 100644 --- a/internal/proxy/views/d2d_backup/panels/exclusions.js +++ b/internal/proxy/views/d2d_backup/panels/exclusions.js @@ -1,110 +1,110 @@ -Ext.define("PBS.D2DManagement.ExclusionPanel", { - extend: "Ext.grid.Panel", - alias: "widget.pbsDiskExclusionPanel", - - controller: { - xclass: "Ext.app.ViewController", - - onAdd: function () { - let me = this; - Ext.create("PBS.D2DManagement.ExclusionEditWindow", { - listeners: { - destroy: function () { - me.reload(); - }, - }, - }).show(); - }, - - onEdit: function () { - let me = this; - let view = me.getView(); - let selection = view.getSelection(); - if (!selection || selection.length < 1) { - return; - } - Ext.create("PBS.D2DManagement.ExclusionEditWindow", { - contentid: selection[0].data.path, - autoLoad: true, - listeners: { - destroy: () => me.reload(), - }, - }).show(); - }, - - reload: function () { - this.getView().getStore().rstore.load(); - }, - - stopStore: function () { - this.getView().getStore().rstore.stopUpdate(); - }, - - startStore: function () { - this.getView().getStore().rstore.startUpdate(); - }, - - init: function (view) { - Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); - }, - }, - - listeners: { - beforedestroy: "stopStore", - deactivate: "stopStore", - activate: "startStore", - itemdblclick: "onEdit", - }, - - store: { - type: "diff", - rstore: { - type: "update", - storeid: "proxmox-disk-exclusions", - model: "pbs-model-exclusions", - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/exclusion", - }, - }, - sorters: "name", - }, - - features: [], - - tbar: [ - { - text: gettext("Add"), - xtype: "proxmoxButton", - handler: "onAdd", - selModel: false, - }, - "-", - { - text: gettext("Edit"), - xtype: "proxmoxButton", - handler: "onEdit", - disabled: true, - }, - { - xtype: "proxmoxStdRemoveButton", - baseurl: pbsPlusBaseUrl + "/api2/extjs/config/d2d-exclusion", - getUrl: (rec) => - pbsPlusBaseUrl + - `/api2/extjs/config/d2d-exclusion/${encodeURIComponent(encodePathValue(rec.getId()))}`, - callback: "reload", - }, - ], - columns: [ - { - text: gettext("Path"), - dataIndex: "path", - flex: 1, - }, - { - text: gettext("Comment"), - dataIndex: "comment", - flex: 2, - }, - ], -}); +Ext.define("PBS.D2DManagement.ExclusionPanel", { + extend: "Ext.grid.Panel", + alias: "widget.pbsDiskExclusionPanel", + + controller: { + xclass: "Ext.app.ViewController", + + onAdd: function () { + let me = this; + Ext.create("PBS.D2DManagement.ExclusionEditWindow", { + listeners: { + destroy: 
function () { + me.reload(); + }, + }, + }).show(); + }, + + onEdit: function () { + let me = this; + let view = me.getView(); + let selection = view.getSelection(); + if (!selection || selection.length < 1) { + return; + } + Ext.create("PBS.D2DManagement.ExclusionEditWindow", { + contentid: selection[0].data.path, + autoLoad: true, + listeners: { + destroy: () => me.reload(), + }, + }).show(); + }, + + reload: function () { + this.getView().getStore().rstore.load(); + }, + + stopStore: function () { + this.getView().getStore().rstore.stopUpdate(); + }, + + startStore: function () { + this.getView().getStore().rstore.startUpdate(); + }, + + init: function (view) { + Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); + }, + }, + + listeners: { + beforedestroy: "stopStore", + deactivate: "stopStore", + activate: "startStore", + itemdblclick: "onEdit", + }, + + store: { + type: "diff", + rstore: { + type: "update", + storeid: "proxmox-disk-exclusions", + model: "pbs-model-exclusions", + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/exclusion", + }, + }, + sorters: "name", + }, + + features: [], + + tbar: [ + { + text: gettext("Add"), + xtype: "proxmoxButton", + handler: "onAdd", + selModel: false, + }, + "-", + { + text: gettext("Edit"), + xtype: "proxmoxButton", + handler: "onEdit", + disabled: true, + }, + { + xtype: "proxmoxStdRemoveButton", + baseurl: pbsPlusBaseUrl + "/api2/extjs/config/d2d-exclusion", + getUrl: (rec) => + pbsPlusBaseUrl + + `/api2/extjs/config/d2d-exclusion/${encodeURIComponent(encodePathValue(rec.getId()))}`, + callback: "reload", + }, + ], + columns: [ + { + text: gettext("Path"), + dataIndex: "path", + flex: 1, + }, + { + text: gettext("Comment"), + dataIndex: "comment", + flex: 2, + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/panels/jobs.js b/internal/proxy/views/d2d_backup/panels/jobs.js index c7b1648..9a44d0e 100644 --- a/internal/proxy/views/d2d_backup/panels/jobs.js +++ b/internal/proxy/views/d2d_backup/panels/jobs.js @@ -1,231 +1,231 @@ -Ext.define("PBS.config.DiskBackupJobView", { - extend: "Ext.grid.GridPanel", - alias: "widget.pbsDiskBackupJobView", - - stateful: true, - stateId: "grid-disk-backup-jobs-v1", - - title: "Disk Backup Jobs", - - controller: { - xclass: "Ext.app.ViewController", - - addJob: function () { - let me = this; - Ext.create("PBS.D2DManagement.BackupJobEdit", { - autoShow: true, - listeners: { - destroy: function () { - me.reload(); - }, - }, - }).show(); - }, - - editJob: function () { - let me = this; - let view = me.getView(); - let selection = view.getSelection(); - if (!selection || selection.length < 1) { - return; - } - - Ext.create("PBS.D2DManagement.BackupJobEdit", { - id: selection[0].data.id, - autoShow: true, - listeners: { - destroy: function () { - me.reload(); - }, - }, - }).show(); - }, - - openTaskLog: function () { - let me = this; - let view = me.getView(); - let selection = view.getSelection(); - if (selection.length < 1) return; - - let upid = selection[0].data["last-run-upid"]; - if (!upid) return; - - Ext.create("Proxmox.window.TaskViewer", { - upid, - }).show(); - }, - - runJob: function () { - let me = this; - let view = me.getView(); - let selection = view.getSelection(); - if (selection.length < 1) return; - - let id = selection[0].data.id; - - Ext.create("PBS.D2DManagement.BackupWindow", { - id, - listeners: { - destroy: function () { - me.reload(); - }, - }, - }).show(); - }, - - startStore: function () { - this.getView().getStore().rstore.startUpdate(); - }, - - 
stopStore: function () { - this.getView().getStore().rstore.stopUpdate(); - }, - - reload: function () { - this.getView().getStore().rstore.load(); - }, - - init: function (view) { - Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); - }, - }, - - listeners: { - activate: "startStore", - deactivate: "stopStore", - itemdblclick: "editJob", - }, - - store: { - type: "diff", - autoDestroy: true, - autoDestroyRstore: true, - sorters: "id", - rstore: { - type: "update", - storeid: "pbs-disk-backup-job-status", - model: "pbs-disk-backup-job-status", - interval: 5000, - }, - }, - - viewConfig: { - trackOver: false, - }, - - tbar: [ - { - xtype: "proxmoxButton", - text: gettext("Add"), - selModel: false, - handler: "addJob", - }, - { - xtype: "proxmoxButton", - text: gettext("Edit"), - handler: "editJob", - disabled: true, - }, - { - xtype: "proxmoxStdRemoveButton", - baseurl: pbsPlusBaseUrl + "/api2/extjs/config/disk-backup-job", - getUrl: (rec) => - pbsPlusBaseUrl + - `/api2/extjs/config/disk-backup-job/${encodeURIComponent(encodePathValue(rec.getId()))}`, - confirmMsg: gettext("Remove entry?"), - callback: "reload", - }, - "-", - { - xtype: "proxmoxButton", - text: gettext("Show Log"), - handler: "openTaskLog", - enableFn: (rec) => !!rec.data["last-run-upid"], - disabled: true, - }, - { - xtype: "proxmoxButton", - text: gettext("Run now"), - handler: "runJob", - reference: "d2dBackupRun", - disabled: true, - }, - ], - - columns: [ - { - header: gettext("Job ID"), - dataIndex: "id", - renderer: Ext.String.htmlEncode, - maxWidth: 220, - minWidth: 75, - flex: 1, - sortable: true, - }, - { - header: gettext("Target"), - dataIndex: "target", - width: 120, - sortable: true, - }, - { - header: gettext("Subpath"), - dataIndex: "subpath", - width: 120, - sortable: true, - }, - { - header: gettext("Datastore"), - dataIndex: "store", - width: 120, - sortable: true, - }, - { - header: gettext("Schedule"), - dataIndex: "schedule", - maxWidth: 220, - minWidth: 80, - flex: 1, - sortable: true, - }, - { - header: gettext("Last Backup"), - dataIndex: "last-run-endtime", - renderer: PBS.Utils.render_optional_timestamp, - width: 150, - sortable: true, - }, - { - text: gettext("Duration"), - dataIndex: "duration", - renderer: Proxmox.Utils.render_duration, - width: 80, - }, - { - header: gettext("Status"), - dataIndex: "last-run-state", - renderer: PBS.Utils.render_task_status, - flex: 1, - }, - { - header: gettext("Next Run"), - dataIndex: "next-run", - renderer: PBS.Utils.render_next_task_run, - width: 150, - sortable: true, - }, - { - header: gettext("Comment"), - dataIndex: "comment", - renderer: Ext.String.htmlEncode, - flex: 2, - sortable: true, - }, - ], - - initComponent: function () { - let me = this; - - me.callParent(); - }, -}); +Ext.define("PBS.config.DiskBackupJobView", { + extend: "Ext.grid.GridPanel", + alias: "widget.pbsDiskBackupJobView", + + stateful: true, + stateId: "grid-disk-backup-jobs-v1", + + title: "Disk Backup Jobs", + + controller: { + xclass: "Ext.app.ViewController", + + addJob: function () { + let me = this; + Ext.create("PBS.D2DManagement.BackupJobEdit", { + autoShow: true, + listeners: { + destroy: function () { + me.reload(); + }, + }, + }).show(); + }, + + editJob: function () { + let me = this; + let view = me.getView(); + let selection = view.getSelection(); + if (!selection || selection.length < 1) { + return; + } + + Ext.create("PBS.D2DManagement.BackupJobEdit", { + id: selection[0].data.id, + autoShow: true, + listeners: { + destroy: function () { + me.reload(); + 
}, + }, + }).show(); + }, + + openTaskLog: function () { + let me = this; + let view = me.getView(); + let selection = view.getSelection(); + if (selection.length < 1) return; + + let upid = selection[0].data["last-run-upid"]; + if (!upid) return; + + Ext.create("Proxmox.window.TaskViewer", { + upid, + }).show(); + }, + + runJob: function () { + let me = this; + let view = me.getView(); + let selection = view.getSelection(); + if (selection.length < 1) return; + + let id = selection[0].data.id; + + Ext.create("PBS.D2DManagement.BackupWindow", { + id, + listeners: { + destroy: function () { + me.reload(); + }, + }, + }).show(); + }, + + startStore: function () { + this.getView().getStore().rstore.startUpdate(); + }, + + stopStore: function () { + this.getView().getStore().rstore.stopUpdate(); + }, + + reload: function () { + this.getView().getStore().rstore.load(); + }, + + init: function (view) { + Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); + }, + }, + + listeners: { + activate: "startStore", + deactivate: "stopStore", + itemdblclick: "editJob", + }, + + store: { + type: "diff", + autoDestroy: true, + autoDestroyRstore: true, + sorters: "id", + rstore: { + type: "update", + storeid: "pbs-disk-backup-job-status", + model: "pbs-disk-backup-job-status", + interval: 5000, + }, + }, + + viewConfig: { + trackOver: false, + }, + + tbar: [ + { + xtype: "proxmoxButton", + text: gettext("Add"), + selModel: false, + handler: "addJob", + }, + { + xtype: "proxmoxButton", + text: gettext("Edit"), + handler: "editJob", + disabled: true, + }, + { + xtype: "proxmoxStdRemoveButton", + baseurl: pbsPlusBaseUrl + "/api2/extjs/config/disk-backup-job", + getUrl: (rec) => + pbsPlusBaseUrl + + `/api2/extjs/config/disk-backup-job/${encodeURIComponent(encodePathValue(rec.getId()))}`, + confirmMsg: gettext("Remove entry?"), + callback: "reload", + }, + "-", + { + xtype: "proxmoxButton", + text: gettext("Show Log"), + handler: "openTaskLog", + enableFn: (rec) => !!rec.data["last-run-upid"], + disabled: true, + }, + { + xtype: "proxmoxButton", + text: gettext("Run now"), + handler: "runJob", + reference: "d2dBackupRun", + disabled: true, + }, + ], + + columns: [ + { + header: gettext("Job ID"), + dataIndex: "id", + renderer: Ext.String.htmlEncode, + maxWidth: 220, + minWidth: 75, + flex: 1, + sortable: true, + }, + { + header: gettext("Target"), + dataIndex: "target", + width: 120, + sortable: true, + }, + { + header: gettext("Subpath"), + dataIndex: "subpath", + width: 120, + sortable: true, + }, + { + header: gettext("Datastore"), + dataIndex: "store", + width: 120, + sortable: true, + }, + { + header: gettext("Schedule"), + dataIndex: "schedule", + maxWidth: 220, + minWidth: 80, + flex: 1, + sortable: true, + }, + { + header: gettext("Last Backup"), + dataIndex: "last-run-endtime", + renderer: PBS.Utils.render_optional_timestamp, + width: 150, + sortable: true, + }, + { + text: gettext("Duration"), + dataIndex: "duration", + renderer: Proxmox.Utils.render_duration, + width: 80, + }, + { + header: gettext("Status"), + dataIndex: "last-run-state", + renderer: PBS.Utils.render_task_status, + flex: 1, + }, + { + header: gettext("Next Run"), + dataIndex: "next-run", + renderer: PBS.Utils.render_next_task_run, + width: 150, + sortable: true, + }, + { + header: gettext("Comment"), + dataIndex: "comment", + renderer: Ext.String.htmlEncode, + flex: 2, + sortable: true, + }, + ], + + initComponent: function () { + let me = this; + + me.callParent(); + }, +}); diff --git 
a/internal/proxy/views/d2d_backup/panels/targets.js b/internal/proxy/views/d2d_backup/panels/targets.js index f7c114d..d6acc4a 100644 --- a/internal/proxy/views/d2d_backup/panels/targets.js +++ b/internal/proxy/views/d2d_backup/panels/targets.js @@ -1,153 +1,153 @@ -Ext.define("PBS.D2DManagement.TargetPanel", { - extend: "Ext.grid.Panel", - alias: "widget.pbsDiskTargetPanel", - - controller: { - xclass: "Ext.app.ViewController", - - onAdd: function () { - let me = this; - Ext.create("PBS.D2DManagement.TargetEditWindow", { - listeners: { - destroy: function () { - me.reload(); - }, - }, - }).show(); - }, - - onEdit: function () { - let me = this; - let view = me.getView(); - let selection = view.getSelection(); - if (!selection || selection.length < 1) { - return; - } - Ext.create("PBS.D2DManagement.TargetEditWindow", { - contentid: selection[0].data.name, - autoLoad: true, - listeners: { - destroy: () => me.reload(), - }, - }).show(); - }, - - reload: function () { - this.getView().getStore().rstore.load(); - }, - - stopStore: function () { - this.getView().getStore().rstore.stopUpdate(); - }, - - startStore: function () { - this.getView().getStore().rstore.startUpdate(); - }, - - render_status: function (value) { - if (value.toString() == "true") { - icon = "check good"; - text = "Reachable"; - } else { - icon = "times critical"; - text = "Unreachable"; - } - - return `<i class="fa fa-${icon}"></i> ${text}`; - }, - - init: function (view) { - Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); - }, - }, - - listeners: { - beforedestroy: "stopStore", - deactivate: "stopStore", - activate: "startStore", - itemdblclick: "onEdit", - }, - - store: { - type: "diff", - rstore: { - type: "update", - storeid: "proxmox-disk-targets", - model: "pbs-model-targets", - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/target", - }, - }, - sorters: "name", - }, - - features: [], - - tbar: [ - { - text: gettext("Add"), - xtype: "proxmoxButton", - handler: "onAdd", - selModel: false, - }, - "-", - { - text: gettext("Edit"), - xtype: "proxmoxButton", - handler: "onEdit", - disabled: true, - }, - { - xtype: "proxmoxStdRemoveButton", - baseurl: pbsPlusBaseUrl + "/api2/extjs/config/d2d-target", - getUrl: (rec) => - pbsPlusBaseUrl + - `/api2/extjs/config/d2d-target/${encodeURIComponent(encodePathValue(rec.getId()))}`, - callback: "reload", - }, - ], - columns: [ - { - text: gettext("Name"), - dataIndex: "name", - flex: 1, - }, - { - text: gettext("Path"), - dataIndex: "path", - flex: 2, - }, - { - text: gettext("Drive Type"), - dataIndex: "drive_type", - flex: 1, - }, - { - text: gettext("Drive Name"), - dataIndex: "drive_name", - flex: 1, - }, - { - text: gettext("Drive FS"), - dataIndex: "drive_fs", - flex: 1, - }, - { - text: gettext("Drive Used"), - dataIndex: "drive_used", - flex: 1, - }, - { - header: gettext("Status"), - dataIndex: "connection_status", - renderer: "render_status", - flex: 1, - }, - { - text: gettext("Agent Version"), - dataIndex: "agent_version", - flex: 1, - }, - ], -}); +Ext.define("PBS.D2DManagement.TargetPanel", { + extend: "Ext.grid.Panel", + alias: "widget.pbsDiskTargetPanel", + + controller: { + xclass: "Ext.app.ViewController", + + onAdd: function () { + let me = this; + Ext.create("PBS.D2DManagement.TargetEditWindow", { + listeners: { + destroy: function () { + me.reload(); + }, + }, + }).show(); + }, + + onEdit: function () { + let me = this; + let view = me.getView(); + let selection = view.getSelection(); + if (!selection || selection.length < 1) { + return; + } + 
Ext.create("PBS.D2DManagement.TargetEditWindow", { + contentid: selection[0].data.name, + autoLoad: true, + listeners: { + destroy: () => me.reload(), + }, + }).show(); + }, + + reload: function () { + this.getView().getStore().rstore.load(); + }, + + stopStore: function () { + this.getView().getStore().rstore.stopUpdate(); + }, + + startStore: function () { + this.getView().getStore().rstore.startUpdate(); + }, + + render_status: function (value) { + if (value.toString() == "true") { + icon = "check good"; + text = "Reachable"; + } else { + icon = "times critical"; + text = "Unreachable"; + } + + return `<i class="fa fa-${icon}"></i> ${text}`; + }, + + init: function (view) { + Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); + }, + }, + + listeners: { + beforedestroy: "stopStore", + deactivate: "stopStore", + activate: "startStore", + itemdblclick: "onEdit", + }, + + store: { + type: "diff", + rstore: { + type: "update", + storeid: "proxmox-disk-targets", + model: "pbs-model-targets", + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/target", + }, + }, + sorters: "name", + }, + + features: [], + + tbar: [ + { + text: gettext("Add"), + xtype: "proxmoxButton", + handler: "onAdd", + selModel: false, + }, + "-", + { + text: gettext("Edit"), + xtype: "proxmoxButton", + handler: "onEdit", + disabled: true, + }, + { + xtype: "proxmoxStdRemoveButton", + baseurl: pbsPlusBaseUrl + "/api2/extjs/config/d2d-target", + getUrl: (rec) => + pbsPlusBaseUrl + + `/api2/extjs/config/d2d-target/${encodeURIComponent(encodePathValue(rec.getId()))}`, + callback: "reload", + }, + ], + columns: [ + { + text: gettext("Name"), + dataIndex: "name", + flex: 1, + }, + { + text: gettext("Path"), + dataIndex: "path", + flex: 2, + }, + { + text: gettext("Drive Type"), + dataIndex: "drive_type", + flex: 1, + }, + { + text: gettext("Drive Name"), + dataIndex: "drive_name", + flex: 1, + }, + { + text: gettext("Drive FS"), + dataIndex: "drive_fs", + flex: 1, + }, + { + text: gettext("Drive Used"), + dataIndex: "drive_used", + flex: 1, + }, + { + header: gettext("Status"), + dataIndex: "connection_status", + renderer: "render_status", + flex: 1, + }, + { + text: gettext("Agent Version"), + dataIndex: "agent_version", + flex: 1, + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/panels/tokens.js b/internal/proxy/views/d2d_backup/panels/tokens.js index 0af29ad..63e048c 100644 --- a/internal/proxy/views/d2d_backup/panels/tokens.js +++ b/internal/proxy/views/d2d_backup/panels/tokens.js @@ -1,161 +1,161 @@ -Ext.define("PBS.D2DManagement.TokenPanel", { - extend: "Ext.grid.Panel", - alias: "widget.pbsDiskTokenPanel", - - controller: { - xclass: "Ext.app.ViewController", - - onAdd: function () { - let me = this; - Ext.create("PBS.D2DManagement.TokenEditWindow", { - listeners: { - destroy: function () { - me.reload(); - }, - }, - }).show(); - }, - - onCopy: async function () { - let me = this; - let view = me.getView(); - let selection = view.getSelection(); - if (!selection || selection.length < 1) { - return; - } - - let token = selection[0].data.token; - Ext.create("Ext.window.Window", { - modal: true, - width: 600, - title: gettext("Bootstrap Token"), - layout: "form", - bodyPadding: "10 0", - items: [ - { - xtype: "textfield", - inputId: "token", - value: token, - editable: false, - }, - ], - buttons: [ - { - xtype: "button", - iconCls: "fa fa-clipboard", - handler: async function (b) { - await navigator.clipboard.writeText(token); - }, - text: gettext("Copy"), - }, - { - text: gettext("Ok"), - handler: function () { - 
this.up("window").close(); - }, - }, - ], - }).show(); - }, - - reload: function () { - this.getView().getStore().rstore.load(); - }, - - stopStore: function () { - this.getView().getStore().rstore.stopUpdate(); - }, - - startStore: function () { - this.getView().getStore().rstore.startUpdate(); - }, - - render_valid: function (value) { - if (value.toString() == "false") { - icon = "check good"; - text = "Valid"; - } else { - icon = "times critical"; - text = "Invalid"; - } - - return `<i class="fa fa-${icon}"></i> ${text}`; - }, - - init: function (view) { - Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); - }, - }, - - listeners: { - beforedestroy: "stopStore", - deactivate: "stopStore", - activate: "startStore", - itemdblclick: "onCopy", - }, - - store: { - type: "diff", - rstore: { - type: "update", - storeid: "proxmox-agent-tokens", - model: "pbs-model-tokens", - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/token", - }, - }, - sorters: "name", - }, - - features: [], - - tbar: [ - { - text: gettext("Generate Token"), - xtype: "proxmoxButton", - handler: "onAdd", - selModel: false, - }, - "-", - { - text: gettext("Copy Token"), - xtype: "proxmoxButton", - handler: "onCopy", - disabled: true, - }, - { - text: gettext("Revoke Token"), - xtype: "proxmoxStdRemoveButton", - baseurl: pbsPlusBaseUrl + "/api2/extjs/config/d2d-token", - getUrl: (rec) => - pbsPlusBaseUrl + - `/api2/extjs/config/d2d-token/${encodeURIComponent(encodePathValue(rec.getId()))}`, - callback: "reload", - }, - ], - columns: [ - { - text: gettext("Token"), - dataIndex: "token", - flex: 1, - }, - { - text: gettext("Comment"), - dataIndex: "comment", - flex: 2, - }, - { - header: gettext("Validity"), - dataIndex: "revoked", - renderer: "render_valid", - flex: 3, - }, - { - header: gettext("Created At"), - dataIndex: "created_at", - renderer: PBS.Utils.render_optional_timestamp, - flex: 4, - }, - ], -}); +Ext.define("PBS.D2DManagement.TokenPanel", { + extend: "Ext.grid.Panel", + alias: "widget.pbsDiskTokenPanel", + + controller: { + xclass: "Ext.app.ViewController", + + onAdd: function () { + let me = this; + Ext.create("PBS.D2DManagement.TokenEditWindow", { + listeners: { + destroy: function () { + me.reload(); + }, + }, + }).show(); + }, + + onCopy: async function () { + let me = this; + let view = me.getView(); + let selection = view.getSelection(); + if (!selection || selection.length < 1) { + return; + } + + let token = selection[0].data.token; + Ext.create("Ext.window.Window", { + modal: true, + width: 600, + title: gettext("Bootstrap Token"), + layout: "form", + bodyPadding: "10 0", + items: [ + { + xtype: "textfield", + inputId: "token", + value: token, + editable: false, + }, + ], + buttons: [ + { + xtype: "button", + iconCls: "fa fa-clipboard", + handler: async function (b) { + await navigator.clipboard.writeText(token); + }, + text: gettext("Copy"), + }, + { + text: gettext("Ok"), + handler: function () { + this.up("window").close(); + }, + }, + ], + }).show(); + }, + + reload: function () { + this.getView().getStore().rstore.load(); + }, + + stopStore: function () { + this.getView().getStore().rstore.stopUpdate(); + }, + + startStore: function () { + this.getView().getStore().rstore.startUpdate(); + }, + + render_valid: function (value) { + if (value.toString() == "false") { + icon = "check good"; + text = "Valid"; + } else { + icon = "times critical"; + text = "Invalid"; + } + + return `<i class="fa fa-${icon}"></i> ${text}`; + }, + + init: function (view) { + Proxmox.Utils.monStoreErrors(view, view.getStore().rstore); + }, + }, + + 
listeners: { + beforedestroy: "stopStore", + deactivate: "stopStore", + activate: "startStore", + itemdblclick: "onCopy", + }, + + store: { + type: "diff", + rstore: { + type: "update", + storeid: "proxmox-agent-tokens", + model: "pbs-model-tokens", + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/token", + }, + }, + sorters: "name", + }, + + features: [], + + tbar: [ + { + text: gettext("Generate Token"), + xtype: "proxmoxButton", + handler: "onAdd", + selModel: false, + }, + "-", + { + text: gettext("Copy Token"), + xtype: "proxmoxButton", + handler: "onCopy", + disabled: true, + }, + { + text: gettext("Revoke Token"), + xtype: "proxmoxStdRemoveButton", + baseurl: pbsPlusBaseUrl + "/api2/extjs/config/d2d-token", + getUrl: (rec) => + pbsPlusBaseUrl + + `/api2/extjs/config/d2d-token/${encodeURIComponent(encodePathValue(rec.getId()))}`, + callback: "reload", + }, + ], + columns: [ + { + text: gettext("Token"), + dataIndex: "token", + flex: 1, + }, + { + text: gettext("Comment"), + dataIndex: "comment", + flex: 2, + }, + { + header: gettext("Validity"), + dataIndex: "revoked", + renderer: "render_valid", + flex: 3, + }, + { + header: gettext("Created At"), + dataIndex: "created_at", + renderer: PBS.Utils.render_optional_timestamp, + flex: 4, + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/selectors/exclusions.js b/internal/proxy/views/d2d_backup/selectors/exclusions.js index 5eca21a..6a270fc 100644 --- a/internal/proxy/views/d2d_backup/selectors/exclusions.js +++ b/internal/proxy/views/d2d_backup/selectors/exclusions.js @@ -1,47 +1,47 @@ -Ext.define("PBS.form.D2DExclusionSelector", { - extend: "Proxmox.form.ComboGrid", - alias: "widget.pbsD2DExclusionSelector", - - allowBlank: false, - autoSelect: false, - - displayField: "name", - valueField: "name", - value: null, - - store: { - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/exclusion", - }, - autoLoad: true, - sorters: "name", - }, - - listConfig: { - width: 450, - columns: [ - { - text: gettext("Path"), - dataIndex: "path", - sortable: true, - flex: 3, - renderer: Ext.String.htmlEncode, - }, - ], - }, - - initComponent: function () { - let me = this; - - if (me.changer) { - me.store.proxy.extraParams = { - changer: me.changer, - }; - } else { - me.store.proxy.extraParams = {}; - } - - me.callParent(); - }, -}); +Ext.define("PBS.form.D2DExclusionSelector", { + extend: "Proxmox.form.ComboGrid", + alias: "widget.pbsD2DExclusionSelector", + + allowBlank: false, + autoSelect: false, + + displayField: "name", + valueField: "name", + value: null, + + store: { + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/exclusion", + }, + autoLoad: true, + sorters: "name", + }, + + listConfig: { + width: 450, + columns: [ + { + text: gettext("Path"), + dataIndex: "path", + sortable: true, + flex: 3, + renderer: Ext.String.htmlEncode, + }, + ], + }, + + initComponent: function () { + let me = this; + + if (me.changer) { + me.store.proxy.extraParams = { + changer: me.changer, + }; + } else { + me.store.proxy.extraParams = {}; + } + + me.callParent(); + }, +}); diff --git a/internal/proxy/views/d2d_backup/selectors/targets.js b/internal/proxy/views/d2d_backup/selectors/targets.js index 000e4fb..cb84dc5 100644 --- a/internal/proxy/views/d2d_backup/selectors/targets.js +++ b/internal/proxy/views/d2d_backup/selectors/targets.js @@ -1,61 +1,61 @@ -Ext.define("PBS.form.D2DTargetSelector", { - extend: "Proxmox.form.ComboGrid", - alias: "widget.pbsD2DTargetSelector", - - allowBlank: false, - 
autoSelect: false, - - displayField: "name", - valueField: "name", - value: null, - - store: { - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/target", - }, - autoLoad: true, - sorters: "name", - }, - - listConfig: { - width: 450, - columns: [ - { - text: gettext("Name"), - dataIndex: "name", - sortable: true, - flex: 3, - renderer: Ext.String.htmlEncode, - }, - { - text: "Path", - dataIndex: "path", - sortable: true, - flex: 3, - renderer: Ext.String.htmlEncode, - }, - { - text: "Type", - dataIndex: "drive_type", - sortable: true, - flex: 3, - renderer: Ext.String.htmlEncode, - }, - ], - }, - - initComponent: function () { - let me = this; - - if (me.changer) { - me.store.proxy.extraParams = { - changer: me.changer, - }; - } else { - me.store.proxy.extraParams = {}; - } - - me.callParent(); - }, -}); +Ext.define("PBS.form.D2DTargetSelector", { + extend: "Proxmox.form.ComboGrid", + alias: "widget.pbsD2DTargetSelector", + + allowBlank: false, + autoSelect: false, + + displayField: "name", + valueField: "name", + value: null, + + store: { + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/target", + }, + autoLoad: true, + sorters: "name", + }, + + listConfig: { + width: 450, + columns: [ + { + text: gettext("Name"), + dataIndex: "name", + sortable: true, + flex: 3, + renderer: Ext.String.htmlEncode, + }, + { + text: "Path", + dataIndex: "path", + sortable: true, + flex: 3, + renderer: Ext.String.htmlEncode, + }, + { + text: "Type", + dataIndex: "drive_type", + sortable: true, + flex: 3, + renderer: Ext.String.htmlEncode, + }, + ], + }, + + initComponent: function () { + let me = this; + + if (me.changer) { + me.store.proxy.extraParams = { + changer: me.changer, + }; + } else { + me.store.proxy.extraParams = {}; + } + + me.callParent(); + }, +}); diff --git a/internal/proxy/views/d2d_backup/selectors/tokens.js b/internal/proxy/views/d2d_backup/selectors/tokens.js index c360ee4..9fe4ba3 100644 --- a/internal/proxy/views/d2d_backup/selectors/tokens.js +++ b/internal/proxy/views/d2d_backup/selectors/tokens.js @@ -1,54 +1,54 @@ -Ext.define("PBS.form.D2DTokenSelector", { - extend: "Proxmox.form.ComboGrid", - alias: "widget.pbsD2DTokenSelector", - - allowBlank: false, - autoSelect: false, - - displayField: "name", - valueField: "name", - value: null, - - store: { - proxy: { - type: "proxmox", - url: pbsPlusBaseUrl + "/api2/json/d2d/token", - }, - autoLoad: true, - sorters: "name", - }, - - listConfig: { - width: 450, - columns: [ - { - text: gettext("Token"), - dataIndex: "token", - sortable: true, - flex: 3, - renderer: Ext.String.htmlEncode, - }, - { - text: "Comment", - dataIndex: "comment", - sortable: true, - flex: 3, - renderer: Ext.String.htmlEncode, - }, - ], - }, - - initComponent: function () { - let me = this; - - if (me.changer) { - me.store.proxy.extraParams = { - changer: me.changer, - }; - } else { - me.store.proxy.extraParams = {}; - } - - me.callParent(); - }, -}); +Ext.define("PBS.form.D2DTokenSelector", { + extend: "Proxmox.form.ComboGrid", + alias: "widget.pbsD2DTokenSelector", + + allowBlank: false, + autoSelect: false, + + displayField: "name", + valueField: "name", + value: null, + + store: { + proxy: { + type: "proxmox", + url: pbsPlusBaseUrl + "/api2/json/d2d/token", + }, + autoLoad: true, + sorters: "name", + }, + + listConfig: { + width: 450, + columns: [ + { + text: gettext("Token"), + dataIndex: "token", + sortable: true, + flex: 3, + renderer: Ext.String.htmlEncode, + }, + { + text: "Comment", + dataIndex: "comment", + sortable: 
true, + flex: 3, + renderer: Ext.String.htmlEncode, + }, + ], + }, + + initComponent: function () { + let me = this; + + if (me.changer) { + me.store.proxy.extraParams = { + changer: me.changer, + }; + } else { + me.store.proxy.extraParams = {}; + } + + me.callParent(); + }, +}); diff --git a/internal/proxy/views/d2d_backup/windows/backup.js b/internal/proxy/views/d2d_backup/windows/backup.js index 7e3811f..6fdf77c 100644 --- a/internal/proxy/views/d2d_backup/windows/backup.js +++ b/internal/proxy/views/d2d_backup/windows/backup.js @@ -1,71 +1,74 @@ -Ext.define("PBS.D2DManagement.BackupWindow", { - extend: "Proxmox.window.Edit", - mixins: ["Proxmox.Mixin.CBind"], - - id: undefined, - - cbindData: function (config) { - let me = this; - return { - warning: Ext.String.format( - gettext("Manually start backup job '{0}'?"), - me.id, - ), - id: me.id, - }; - }, - - title: gettext("Backup"), - url: pbsPlusBaseUrl + `/api2/extjs/d2d/backup`, - showProgress: true, - submitUrl: function (url, values) { - let id = values.id; - delete values.id; - return `${url}/${encodePathValue(id)}`; - }, - - layout: "hbox", - width: 400, - method: "POST", - isCreate: true, - submitText: gettext("Start Backup"), - items: [ - { - xtype: "container", - padding: 0, - layout: { - type: "hbox", - align: "stretch", - }, - items: [ - { - xtype: "component", - cls: [ - Ext.baseCSSPrefix + "message-box-icon", - Ext.baseCSSPrefix + "message-box-question", - Ext.baseCSSPrefix + "dlg-icon", - ], - }, - { - xtype: "container", - flex: 1, - items: [ - { - xtype: "displayfield", - cbind: { - value: "{warning}", - }, - }, - { - xtype: "hidden", - name: "id", - cbind: { - value: "{id}", - }, - }, - ], - }, - ], - }, - ], -}); +Ext.define("PBS.D2DManagement.BackupWindow", { + extend: "Proxmox.window.Edit", + mixins: ["Proxmox.Mixin.CBind"], + + id: undefined, + + cbindData: function (config) { + let me = this; + return { + warning: Ext.String.format( + gettext("Manually start backup job '{0}'?"), + me.id, + ), + id: me.id, + }; + }, + + title: gettext("Backup"), + url: pbsPlusBaseUrl + `/api2/extjs/d2d/backup`, + showProgress: true, + submitUrl: function (url, values) { + let id = values.id; + delete values.id; + return `${url}/${encodePathValue(id)}`; + }, + submitOptions: { + timeout: 120000, + }, + + layout: "hbox", + width: 400, + method: "POST", + isCreate: true, + submitText: gettext("Start Backup"), + items: [ + { + xtype: "container", + padding: 0, + layout: { + type: "hbox", + align: "stretch", + }, + items: [ + { + xtype: "component", + cls: [ + Ext.baseCSSPrefix + "message-box-icon", + Ext.baseCSSPrefix + "message-box-question", + Ext.baseCSSPrefix + "dlg-icon", + ], + }, + { + xtype: "container", + flex: 1, + items: [ + { + xtype: "displayfield", + cbind: { + value: "{warning}", + }, + }, + { + xtype: "hidden", + name: "id", + cbind: { + value: "{id}", + }, + }, + ], + }, + ], + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/windows/exclusion.js b/internal/proxy/views/d2d_backup/windows/exclusion.js index 601b84f..4944f01 100644 --- a/internal/proxy/views/d2d_backup/windows/exclusion.js +++ b/internal/proxy/views/d2d_backup/windows/exclusion.js @@ -1,44 +1,44 @@ -Ext.define("PBS.D2DManagement.ExclusionEditWindow", { - extend: "Proxmox.window.Edit", - alias: "widget.pbsExclusionEditWindow", - mixins: ["Proxmox.Mixin.CBind"], - - isCreate: true, - isAdd: true, - subject: "Disk Backup Global Path Exclusion", - cbindData: function (initialConfig) { - let me = this; - - let contentid = initialConfig.contentid; - 
let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/d2d-exclusion"; - - me.isCreate = !contentid; - me.url = contentid - ? `${baseurl}/${encodeURIComponent(encodePathValue(contentid))}` - : baseurl; - me.method = contentid ? "PUT" : "POST"; - - return {}; - }, - - items: [ - { - fieldLabel: gettext("Path"), - name: "path", - xtype: "pmxDisplayEditField", - renderer: Ext.htmlEncode, - allowBlank: false, - cbind: { - editable: "{isCreate}", - }, - }, - { - fieldLabel: gettext("Comment"), - xtype: "proxmoxtextfield", - name: "comment", - cbind: { - deleteEmpty: "{!isCreate}", - }, - }, - ], -}); +Ext.define("PBS.D2DManagement.ExclusionEditWindow", { + extend: "Proxmox.window.Edit", + alias: "widget.pbsExclusionEditWindow", + mixins: ["Proxmox.Mixin.CBind"], + + isCreate: true, + isAdd: true, + subject: "Disk Backup Global Path Exclusion", + cbindData: function (initialConfig) { + let me = this; + + let contentid = initialConfig.contentid; + let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/d2d-exclusion"; + + me.isCreate = !contentid; + me.url = contentid + ? `${baseurl}/${encodeURIComponent(encodePathValue(contentid))}` + : baseurl; + me.method = contentid ? "PUT" : "POST"; + + return {}; + }, + + items: [ + { + fieldLabel: gettext("Path"), + name: "path", + xtype: "pmxDisplayEditField", + renderer: Ext.htmlEncode, + allowBlank: false, + cbind: { + editable: "{isCreate}", + }, + }, + { + fieldLabel: gettext("Comment"), + xtype: "proxmoxtextfield", + name: "comment", + cbind: { + deleteEmpty: "{!isCreate}", + }, + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/windows/job.js b/internal/proxy/views/d2d_backup/windows/job.js index d4547fb..aa378f6 100644 --- a/internal/proxy/views/d2d_backup/windows/job.js +++ b/internal/proxy/views/d2d_backup/windows/job.js @@ -1,130 +1,130 @@ -Ext.define("PBS.D2DManagement.BackupJobEdit", { - extend: "Proxmox.window.Edit", - alias: "widget.pbsDiskBackupJobEdit", - mixins: ["Proxmox.Mixin.CBind"], - - userid: undefined, - - isAdd: true, - - subject: gettext("Disk Backup Job"), - - fieldDefaults: { labelWidth: 120 }, - - bodyPadding: 0, - - cbindData: function (initialConfig) { - let me = this; - - let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/disk-backup-job"; - let id = initialConfig.id; - - me.isCreate = !id; - me.url = id ? `${baseurl}/${encodePathValue(id)}` : baseurl; - me.method = id ? "PUT" : "POST"; - me.autoLoad = !!id; - me.scheduleValue = id ? null : "daily"; - me.authid = id ? 
null : Proxmox.UserName; - me.editDatastore = me.datastore === undefined && me.isCreate; - return {}; - }, - - viewModel: {}, - - initComponent: function () { - let me = this; - me.callParent(); - }, - - items: { - xtype: "tabpanel", - bodyPadding: 10, - border: 0, - items: [ - { - title: gettext("Options"), - xtype: "inputpanel", - onGetValues: function (values) { - let me = this; - - if (me.isCreate) { - delete values.delete; - } - - return values; - }, - cbind: { - isCreate: "{isCreate}", // pass it through - }, - column1: [ - { - xtype: "pmxDisplayEditField", - name: "id", - fieldLabel: gettext("Job ID"), - renderer: Ext.htmlEncode, - allowBlank: false, - cbind: { - editable: "{isCreate}", - }, - }, - { - xtype: "pbsD2DTargetSelector", - fieldLabel: "Target", - name: "target", - }, - { - xtype: "proxmoxtextfield", - fieldLabel: gettext("Subpath"), - emptyText: gettext("/"), - name: "subpath", - }, - { - xtype: "pbsDataStoreSelector", - fieldLabel: gettext("Local Datastore"), - name: "store", - }, - { - xtype: "proxmoxtextfield", - fieldLabel: gettext("Namespace"), - emptyText: gettext("Root"), - name: "ns", - }, - ], - - column2: [ - { - fieldLabel: gettext("Schedule"), - xtype: "pbsCalendarEvent", - name: "schedule", - emptyText: gettext("none (disabled)"), - cbind: { - deleteEmpty: "{!isCreate}", - value: "{scheduleValue}", - }, - }, - ], - - columnB: [ - { - fieldLabel: gettext("Comment"), - xtype: "proxmoxtextfield", - name: "comment", - cbind: { - deleteEmpty: "{!isCreate}", - }, - }, - { - xtype: "textarea", - name: "rawexclusions", - height: "100%", - fieldLabel: gettext("Exclusions"), - value: "", - emptyText: gettext( - "Newline delimited list of exclusions following the .pxarexclude patterns.", - ), - }, - ], - }, - ], - }, -}); +Ext.define("PBS.D2DManagement.BackupJobEdit", { + extend: "Proxmox.window.Edit", + alias: "widget.pbsDiskBackupJobEdit", + mixins: ["Proxmox.Mixin.CBind"], + + userid: undefined, + + isAdd: true, + + subject: gettext("Disk Backup Job"), + + fieldDefaults: { labelWidth: 120 }, + + bodyPadding: 0, + + cbindData: function (initialConfig) { + let me = this; + + let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/disk-backup-job"; + let id = initialConfig.id; + + me.isCreate = !id; + me.url = id ? `${baseurl}/${encodePathValue(id)}` : baseurl; + me.method = id ? "PUT" : "POST"; + me.autoLoad = !!id; + me.scheduleValue = id ? null : "daily"; + me.authid = id ? 
null : Proxmox.UserName; + me.editDatastore = me.datastore === undefined && me.isCreate; + return {}; + }, + + viewModel: {}, + + initComponent: function () { + let me = this; + me.callParent(); + }, + + items: { + xtype: "tabpanel", + bodyPadding: 10, + border: 0, + items: [ + { + title: gettext("Options"), + xtype: "inputpanel", + onGetValues: function (values) { + let me = this; + + if (me.isCreate) { + delete values.delete; + } + + return values; + }, + cbind: { + isCreate: "{isCreate}", // pass it through + }, + column1: [ + { + xtype: "pmxDisplayEditField", + name: "id", + fieldLabel: gettext("Job ID"), + renderer: Ext.htmlEncode, + allowBlank: false, + cbind: { + editable: "{isCreate}", + }, + }, + { + xtype: "pbsD2DTargetSelector", + fieldLabel: "Target", + name: "target", + }, + { + xtype: "proxmoxtextfield", + fieldLabel: gettext("Subpath"), + emptyText: gettext("/"), + name: "subpath", + }, + { + xtype: "pbsDataStoreSelector", + fieldLabel: gettext("Local Datastore"), + name: "store", + }, + { + xtype: "proxmoxtextfield", + fieldLabel: gettext("Namespace"), + emptyText: gettext("Root"), + name: "ns", + }, + ], + + column2: [ + { + fieldLabel: gettext("Schedule"), + xtype: "pbsCalendarEvent", + name: "schedule", + emptyText: gettext("none (disabled)"), + cbind: { + deleteEmpty: "{!isCreate}", + value: "{scheduleValue}", + }, + }, + ], + + columnB: [ + { + fieldLabel: gettext("Comment"), + xtype: "proxmoxtextfield", + name: "comment", + cbind: { + deleteEmpty: "{!isCreate}", + }, + }, + { + xtype: "textarea", + name: "rawexclusions", + height: "100%", + fieldLabel: gettext("Exclusions"), + value: "", + emptyText: gettext( + "Newline delimited list of exclusions following the .pxarexclude patterns.", + ), + }, + ], + }, + ], + }, +}); diff --git a/internal/proxy/views/d2d_backup/windows/target.js b/internal/proxy/views/d2d_backup/windows/target.js index d89e807..c1dce91 100644 --- a/internal/proxy/views/d2d_backup/windows/target.js +++ b/internal/proxy/views/d2d_backup/windows/target.js @@ -1,46 +1,46 @@ -Ext.define("PBS.D2DManagement.TargetEditWindow", { - extend: "Proxmox.window.Edit", - alias: "widget.pbsTargetEditWindow", - mixins: ["Proxmox.Mixin.CBind"], - - isCreate: true, - isAdd: true, - subject: "Disk Backup Target", - cbindData: function (initialConfig) { - let me = this; - - let contentid = initialConfig.contentid; - let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/d2d-target"; - - me.isCreate = !contentid; - me.url = contentid - ? `${baseurl}/${encodeURIComponent(encodePathValue(contentid))}` - : baseurl; - me.method = contentid ? "PUT" : "POST"; - - return {}; - }, - - items: [ - { - fieldLabel: gettext("Name"), - name: "name", - xtype: "pmxDisplayEditField", - renderer: Ext.htmlEncode, - allowBlank: false, - cbind: { - editable: "{isCreate}", - }, - }, - { - fieldLabel: gettext("Path"), - name: "path", - xtype: "pmxDisplayEditField", - renderer: Ext.htmlEncode, - allowBlank: false, - cbind: { - editable: "{isCreate}", - }, - }, - ], -}); +Ext.define("PBS.D2DManagement.TargetEditWindow", { + extend: "Proxmox.window.Edit", + alias: "widget.pbsTargetEditWindow", + mixins: ["Proxmox.Mixin.CBind"], + + isCreate: true, + isAdd: true, + subject: "Disk Backup Target", + cbindData: function (initialConfig) { + let me = this; + + let contentid = initialConfig.contentid; + let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/d2d-target"; + + me.isCreate = !contentid; + me.url = contentid + ? 
`${baseurl}/${encodeURIComponent(encodePathValue(contentid))}` + : baseurl; + me.method = contentid ? "PUT" : "POST"; + + return {}; + }, + + items: [ + { + fieldLabel: gettext("Name"), + name: "name", + xtype: "pmxDisplayEditField", + renderer: Ext.htmlEncode, + allowBlank: false, + cbind: { + editable: "{isCreate}", + }, + }, + { + fieldLabel: gettext("Path"), + name: "path", + xtype: "pmxDisplayEditField", + renderer: Ext.htmlEncode, + allowBlank: false, + cbind: { + editable: "{isCreate}", + }, + }, + ], +}); diff --git a/internal/proxy/views/d2d_backup/windows/token.js b/internal/proxy/views/d2d_backup/windows/token.js index 84f1b17..002785b 100644 --- a/internal/proxy/views/d2d_backup/windows/token.js +++ b/internal/proxy/views/d2d_backup/windows/token.js @@ -1,36 +1,36 @@ -Ext.define("PBS.D2DManagement.TokenEditWindow", { - extend: "Proxmox.window.Edit", - alias: "widget.pbsTokenEditWindow", - mixins: ["Proxmox.Mixin.CBind"], - - isCreate: true, - isAdd: true, - subject: "Agent Bootstrap Token", - cbindData: function (initialConfig) { - let me = this; - - let contentid = initialConfig.contentid; - let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/d2d-token"; - - me.isCreate = !contentid; - me.url = contentid - ? `${baseurl}/${encodeURIComponent(encodePathValue(contentid))}` - : baseurl; - me.method = contentid ? "PUT" : "POST"; - - return {}; - }, - - items: [ - { - fieldLabel: gettext("Comment"), - name: "comment", - xtype: "pmxDisplayEditField", - renderer: Ext.htmlEncode, - allowBlank: false, - cbind: { - editable: "{isCreate}", - }, - }, - ], -}); +Ext.define("PBS.D2DManagement.TokenEditWindow", { + extend: "Proxmox.window.Edit", + alias: "widget.pbsTokenEditWindow", + mixins: ["Proxmox.Mixin.CBind"], + + isCreate: true, + isAdd: true, + subject: "Agent Bootstrap Token", + cbindData: function (initialConfig) { + let me = this; + + let contentid = initialConfig.contentid; + let baseurl = pbsPlusBaseUrl + "/api2/extjs/config/d2d-token"; + + me.isCreate = !contentid; + me.url = contentid + ? `${baseurl}/${encodeURIComponent(encodePathValue(contentid))}` + : baseurl; + me.method = contentid ? 
"PUT" : "POST"; + + return {}; + }, + + items: [ + { + fieldLabel: gettext("Comment"), + name: "comment", + xtype: "pmxDisplayEditField", + renderer: Ext.htmlEncode, + allowBlank: false, + cbind: { + editable: "{isCreate}", + }, + }, + ], +}); diff --git a/internal/proxy/views/navigation.js b/internal/proxy/views/navigation.js index bf9a401..d1388bf 100644 --- a/internal/proxy/views/navigation.js +++ b/internal/proxy/views/navigation.js @@ -1,122 +1,122 @@ -Ext.define("PBS.store.NavigationStore", { - extend: "Ext.data.TreeStore", - - storeId: "NavigationStore", - - root: { - expanded: true, - children: [ - { - text: gettext("Dashboard"), - iconCls: "fa fa-tachometer", - path: "pbsDashboard", - leaf: true, - }, - { - text: gettext("Notes"), - iconCls: "fa fa-sticky-note-o", - path: "pbsNodeNotes", - leaf: true, - }, - { - text: gettext("Configuration"), - iconCls: "fa fa-gears", - path: "pbsSystemConfiguration", - expanded: true, - children: [ - { - text: gettext("Access Control"), - iconCls: "fa fa-key", - path: "pbsAccessControlPanel", - leaf: true, - }, - { - text: gettext("Remotes"), - iconCls: "fa fa-server", - path: "pbsRemoteView", - leaf: true, - }, - { - text: gettext("Traffic Control"), - iconCls: "fa fa-signal fa-rotate-90", - path: "pbsTrafficControlView", - leaf: true, - }, - { - text: gettext("Certificates"), - iconCls: "fa fa-certificate", - path: "pbsCertificateConfiguration", - leaf: true, - }, - { - text: gettext("Notifications"), - iconCls: "fa fa-bell-o", - path: "pbsNotificationConfigView", - leaf: true, - }, - { - text: gettext("Subscription"), - iconCls: "fa fa-support", - path: "pbsSubscription", - leaf: true, - }, - ], - }, - { - text: gettext("Administration"), - iconCls: "fa fa-wrench", - path: "pbsServerAdministration", - expanded: true, - leaf: false, - children: [ - { - text: gettext("Shell"), - iconCls: "fa fa-terminal", - path: "pbsXtermJsConsole", - leaf: true, - }, - { - text: gettext("Storage / Disks"), - iconCls: "fa fa-hdd-o", - path: "pbsStorageAndDiskPanel", - leaf: true, - }, - ], - }, - { - text: "Disk Backup", - iconCls: "fa fa-hdd-o", - id: "backup_targets", - path: "pbsD2DManagement", - expanded: true, - children: [], - }, - { - text: "Tape Backup", - iconCls: "pbs-icon-tape", - id: "tape_management", - path: "pbsTapeManagement", - expanded: true, - children: [], - }, - { - text: gettext("Datastore"), - iconCls: "fa fa-archive", - id: "datastores", - path: "pbsDataStores", - expanded: true, - expandable: false, - leaf: false, - children: [ - { - text: gettext("Add Datastore"), - iconCls: "fa fa-plus-circle", - leaf: true, - id: "addbutton", - virtualEntry: true, - }, - ], - }, - ], - }, -}); +Ext.define("PBS.store.NavigationStore", { + extend: "Ext.data.TreeStore", + + storeId: "NavigationStore", + + root: { + expanded: true, + children: [ + { + text: gettext("Dashboard"), + iconCls: "fa fa-tachometer", + path: "pbsDashboard", + leaf: true, + }, + { + text: gettext("Notes"), + iconCls: "fa fa-sticky-note-o", + path: "pbsNodeNotes", + leaf: true, + }, + { + text: gettext("Configuration"), + iconCls: "fa fa-gears", + path: "pbsSystemConfiguration", + expanded: true, + children: [ + { + text: gettext("Access Control"), + iconCls: "fa fa-key", + path: "pbsAccessControlPanel", + leaf: true, + }, + { + text: gettext("Remotes"), + iconCls: "fa fa-server", + path: "pbsRemoteView", + leaf: true, + }, + { + text: gettext("Traffic Control"), + iconCls: "fa fa-signal fa-rotate-90", + path: "pbsTrafficControlView", + leaf: true, + }, + { + text: 
gettext("Certificates"), + iconCls: "fa fa-certificate", + path: "pbsCertificateConfiguration", + leaf: true, + }, + { + text: gettext("Notifications"), + iconCls: "fa fa-bell-o", + path: "pbsNotificationConfigView", + leaf: true, + }, + { + text: gettext("Subscription"), + iconCls: "fa fa-support", + path: "pbsSubscription", + leaf: true, + }, + ], + }, + { + text: gettext("Administration"), + iconCls: "fa fa-wrench", + path: "pbsServerAdministration", + expanded: true, + leaf: false, + children: [ + { + text: gettext("Shell"), + iconCls: "fa fa-terminal", + path: "pbsXtermJsConsole", + leaf: true, + }, + { + text: gettext("Storage / Disks"), + iconCls: "fa fa-hdd-o", + path: "pbsStorageAndDiskPanel", + leaf: true, + }, + ], + }, + { + text: "Disk Backup", + iconCls: "fa fa-hdd-o", + id: "backup_targets", + path: "pbsD2DManagement", + expanded: true, + children: [], + }, + { + text: "Tape Backup", + iconCls: "pbs-icon-tape", + id: "tape_management", + path: "pbsTapeManagement", + expanded: true, + children: [], + }, + { + text: gettext("Datastore"), + iconCls: "fa fa-archive", + id: "datastores", + path: "pbsDataStores", + expanded: true, + expandable: false, + leaf: false, + children: [ + { + text: gettext("Add Datastore"), + iconCls: "fa fa-plus-circle", + leaf: true, + id: "addbutton", + virtualEntry: true, + }, + ], + }, + ], + }, +}); diff --git a/internal/store/constants/constants.go b/internal/store/constants/constants.go index 06130b5..9844948 100644 --- a/internal/store/constants/constants.go +++ b/internal/store/constants/constants.go @@ -1,11 +1,11 @@ -package constants - -const ( - ProxyTargetURL = "https://127.0.0.1:8007" // The target server URL - ModifiedFilePath = "/js/proxmox-backup-gui.js" // The specific JS file to modify - CertFile = "/etc/proxmox-backup/proxy.pem" // Path to generated SSL certificate - KeyFile = "/etc/proxmox-backup/proxy.key" // Path to generated private key - TimerBasePath = "/lib/systemd/system" - DbBasePath = "/var/lib/proxmox-backup" - AgentMountBasePath = "/mnt/pbs-plus-mounts" -) +package constants + +const ( + ProxyTargetURL = "https://127.0.0.1:8007" // The target server URL + ModifiedFilePath = "/js/proxmox-backup-gui.js" // The specific JS file to modify + CertFile = "/etc/proxmox-backup/proxy.pem" // Path to generated SSL certificate + KeyFile = "/etc/proxmox-backup/proxy.key" // Path to generated private key + TimerBasePath = "/lib/systemd/system" + DbBasePath = "/var/lib/proxmox-backup" + AgentMountBasePath = "/mnt/pbs-plus-mounts" +) diff --git a/internal/store/constants/default_exclusions.go b/internal/store/constants/default_exclusions.go index e258876..61562bb 100644 --- a/internal/store/constants/default_exclusions.go +++ b/internal/store/constants/default_exclusions.go @@ -1,61 +1,61 @@ -package constants - -var DefaultExclusions = []string{ - `hiberfil.sys`, - `pagefile.sys`, - `swapfile.sys`, - `autoexec.bat`, - `Config.Msi`, - `Documents and Settings`, - `Recycled`, - `Recycler`, - `$$Recycle.Bin`, - `Recovery`, - `Program Files`, - `Program Files (x86)`, - `ProgramData`, - `PerfLogs`, - `Windows`, - `Windows.old`, - `$$WINDOWS.~BT`, - `$$WinREAgent`, - `$RECYCLE.BIN`, - `$WinREAgent`, - `System Volume Information`, - `Temporary Internet Files`, - `Microsoft/Windows/Recent`, - `Microsoft/**/RecoveryStore`, - `Microsoft/**/Windows/**/*.edb`, - `Microsoft/**/Windows/**/*.log`, - `Microsoft/**/Windows/Cookies*`, - `Microsoft/**/Logs*`, - `Users/Public/AccountPictures`, - `I386`, - `Internet Explorer`, - `MSOCache`, - `NTUSER*`, - 
`UsrClass.dat`, - `Thumbs.db`, - `AppData/Local/Temp*`, - `AppData/Temp*`, - `Local Settings/Temp*`, - `**.tmp`, - `AppData/**/*cache*`, - `AppData/**/Crash Reports`, - `AppData/Local/Apple Computer/Mobile Sync`, - `AppData/Local/Comms/UnistoreDB`, - `AppData/Local/ElevatedDiagnostics`, - `AppData/Local/Microsoft/Windows/Explorer`, - `AppData/Local/Microsoft/Windows/INetCache`, - `AppData/Local/Microsoft/Windows/UPPS`, - `AppData/Local/Microsoft/Windows/WebCache`, - `AppData/Local/Microsoft/Windows Store`, - `AppData/Local/Packages`, - `Application Data/Apple Computer/Mobile Sync`, - `Application Data/Application Data*`, - `iPhoto Library/iPod Photo Cache`, - `cookies.sqlite-*`, - `permissions.sqlite-*`, - `Local Settings/History`, - `OneDrive/.849C9593-D756-4E56-8D6E-42412F2A707B`, -} +package constants + +var DefaultExclusions = []string{ + `hiberfil.sys`, + `pagefile.sys`, + `swapfile.sys`, + `autoexec.bat`, + `Config.Msi`, + `Documents and Settings`, + `Recycled`, + `Recycler`, + `$$Recycle.Bin`, + `Recovery`, + `Program Files`, + `Program Files (x86)`, + `ProgramData`, + `PerfLogs`, + `Windows`, + `Windows.old`, + `$$WINDOWS.~BT`, + `$$WinREAgent`, + `$RECYCLE.BIN`, + `$WinREAgent`, + `System Volume Information`, + `Temporary Internet Files`, + `Microsoft/Windows/Recent`, + `Microsoft/**/RecoveryStore`, + `Microsoft/**/Windows/**/*.edb`, + `Microsoft/**/Windows/**/*.log`, + `Microsoft/**/Windows/Cookies*`, + `Microsoft/**/Logs*`, + `Users/Public/AccountPictures`, + `I386`, + `Internet Explorer`, + `MSOCache`, + `NTUSER*`, + `UsrClass.dat`, + `Thumbs.db`, + `AppData/Local/Temp*`, + `AppData/Temp*`, + `Local Settings/Temp*`, + `**.tmp`, + `AppData/**/*cache*`, + `AppData/**/Crash Reports`, + `AppData/Local/Apple Computer/Mobile Sync`, + `AppData/Local/Comms/UnistoreDB`, + `AppData/Local/ElevatedDiagnostics`, + `AppData/Local/Microsoft/Windows/Explorer`, + `AppData/Local/Microsoft/Windows/INetCache`, + `AppData/Local/Microsoft/Windows/UPPS`, + `AppData/Local/Microsoft/Windows/WebCache`, + `AppData/Local/Microsoft/Windows Store`, + `AppData/Local/Packages`, + `Application Data/Apple Computer/Mobile Sync`, + `Application Data/Application Data*`, + `iPhoto Library/iPod Photo Cache`, + `cookies.sqlite-*`, + `permissions.sqlite-*`, + `Local Settings/History`, + `OneDrive/.849C9593-D756-4E56-8D6E-42412F2A707B`, +} diff --git a/internal/store/constants/version.go b/internal/store/constants/version.go index 238ed6f..3d4a58f 100644 --- a/internal/store/constants/version.go +++ b/internal/store/constants/version.go @@ -1,3 +1,3 @@ -package constants - -var Version string +package constants + +var Version string diff --git a/internal/store/database/exclusions.go b/internal/store/database/exclusions.go index 943c0c2..17bf8a0 100644 --- a/internal/store/database/exclusions.go +++ b/internal/store/database/exclusions.go @@ -1,270 +1,270 @@ -//go:build linux - -package database - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - configLib "github.com/sonroyaalmerol/pbs-plus/internal/config" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" - "github.com/sonroyaalmerol/pbs-plus/internal/utils/pattern" -) - -func (database *Database) RegisterExclusionPlugin() { - plugin := &configLib.SectionPlugin[types.Exclusion]{ - TypeName: "exclusion", - FolderPath: database.paths["exclusions"], - Validate: func(config types.Exclusion) error { - if !pattern.IsValidPattern(config.Path) { - return fmt.Errorf("invalid exclusion pattern: %s", 
config.Path) - } - return nil - }, - } - - database.exclusionsConfig = configLib.NewSectionConfig(plugin) -} - -func (database *Database) CreateExclusion(exclusion types.Exclusion) error { - exclusion.Path = strings.ReplaceAll(exclusion.Path, "\\", "/") - - if !pattern.IsValidPattern(exclusion.Path) { - return fmt.Errorf("CreateExclusion: invalid path pattern -> %s", exclusion.Path) - } - - filename := "global" - if exclusion.JobID != "" { - filename = utils.EncodePath(exclusion.JobID) - } - - configPath := filepath.Join(database.paths["exclusions"], filename+".cfg") - - // Read existing exclusions - var configData *configLib.ConfigData[types.Exclusion] - existing, err := database.exclusionsConfig.Parse(configPath) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("CreateExclusion: error reading existing config: %w", err) - } - - if existing != nil { - configData = existing - } else { - configData = &configLib.ConfigData[types.Exclusion]{ - Sections: make(map[string]*configLib.Section[types.Exclusion]), - Order: make([]string, 0), - FilePath: configPath, - } - } - - // Add new exclusion - sectionID := fmt.Sprintf("excl-%s", exclusion.Path) - configData.Sections[sectionID] = &configLib.Section[types.Exclusion]{ - Type: "exclusion", - ID: sectionID, - Properties: exclusion, - } - configData.Order = append(configData.Order, sectionID) - - if err := database.exclusionsConfig.Write(configData); err != nil { - return fmt.Errorf("CreateExclusion: error writing config: %w", err) - } - - return nil -} - -func (database *Database) GetAllJobExclusions(jobId string) ([]types.Exclusion, error) { - configPath := filepath.Join(database.paths["exclusions"], utils.EncodePath(jobId)+".cfg") - configData, err := database.exclusionsConfig.Parse(configPath) - if err != nil { - if os.IsNotExist(err) { - return []types.Exclusion{}, nil - } - return nil, fmt.Errorf("GetAllJobExclusions: error reading config: %w", err) - } - - var exclusions []types.Exclusion - seenPaths := make(map[string]bool) - - for _, sectionID := range configData.Order { - section := configData.Sections[sectionID] - if seenPaths[section.Properties.Path] { - continue // Skip duplicates - } - seenPaths[section.Properties.Path] = true - - exclusions = append(exclusions, section.Properties) - } - return exclusions, nil -} - -func (database *Database) GetAllGlobalExclusions() ([]types.Exclusion, error) { - configPath := filepath.Join(database.paths["exclusions"], "global.cfg") - configData, err := database.exclusionsConfig.Parse(configPath) - if err != nil { - if os.IsNotExist(err) { - return []types.Exclusion{}, nil - } - return nil, fmt.Errorf("GetAllGlobalExclusions: error reading config: %w", err) - } - - var exclusions []types.Exclusion - seenPaths := make(map[string]bool) - - for _, sectionID := range configData.Order { - section := configData.Sections[sectionID] - if seenPaths[section.Properties.Path] { - continue // Skip duplicates - } - seenPaths[section.Properties.Path] = true - - exclusions = append(exclusions, section.Properties) - } - return exclusions, nil -} - -func (database *Database) GetExclusion(path string) (*types.Exclusion, error) { - // Check global exclusions first - globalPath := filepath.Join(database.paths["exclusions"], "global.cfg") - if configData, err := database.exclusionsConfig.Parse(globalPath); err == nil { - sectionID := fmt.Sprintf("excl-%s", path) - if section, exists := configData.Sections[sectionID]; exists { - return &section.Properties, nil - } - } - - // Check job-specific exclusions - files, err 
:= os.ReadDir(database.paths["exclusions"]) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - return nil, fmt.Errorf("GetExclusion: error reading directory: %w", err) - } - - for _, file := range files { - if file.Name() == "global" { - continue - } - configPath := filepath.Join(database.paths["exclusions"], file.Name()) - configData, err := database.exclusionsConfig.Parse(configPath) - if err != nil { - continue - } - - sectionID := fmt.Sprintf("excl-%s", path) - if section, exists := configData.Sections[sectionID]; exists { - return &section.Properties, nil - } - } - - return nil, fmt.Errorf("GetExclusion: exclusion not found for path: %s", path) -} - -func (database *Database) UpdateExclusion(exclusion types.Exclusion) error { - exclusion.Path = strings.ReplaceAll(exclusion.Path, "\\", "/") - if !pattern.IsValidPattern(exclusion.Path) { - return fmt.Errorf("UpdateExclusion: invalid path pattern -> %s", exclusion.Path) - } - - configPath := filepath.Join(database.paths["exclusions"], "global.cfg") - if exclusion.JobID != "" { - configPath = filepath.Join(database.paths["exclusions"], utils.EncodePath(exclusion.JobID)+".cfg") - } - - configData, err := database.exclusionsConfig.Parse(configPath) - if err != nil { - return fmt.Errorf("UpdateExclusion: error reading config: %w", err) - } - - sectionID := fmt.Sprintf("excl-%s", exclusion.Path) - _, exists := configData.Sections[sectionID] - if !exists { - return fmt.Errorf("UpdateExclusion: exclusion not found for path: %s", exclusion.Path) - } - - // Update properties - configData.Sections[sectionID].Properties = exclusion - return database.exclusionsConfig.Write(configData) -} - -func (database *Database) DeleteExclusion(path string) error { - path = strings.ReplaceAll(path, "\\", "/") - sectionID := fmt.Sprintf("excl-%s", path) - - // Try job-specific exclusions first - files, err := os.ReadDir(database.paths["exclusions"]) - if err != nil { - return fmt.Errorf("DeleteExclusion: error reading directory: %w", err) - } - - for _, file := range files { - if file.Name() == "global" { - continue - } - configPath := filepath.Join(database.paths["exclusions"], file.Name()) - configData, err := database.exclusionsConfig.Parse(configPath) - if err != nil { - continue - } - - if _, exists := configData.Sections[sectionID]; exists { - delete(configData.Sections, sectionID) - newOrder := make([]string, 0) - for _, id := range configData.Order { - if id != sectionID { - newOrder = append(newOrder, id) - } - } - configData.Order = newOrder - - // If the config is empty after deletion, remove the file - if len(configData.Sections) == 0 { - if err := os.Remove(configPath); err != nil { - return fmt.Errorf("DeleteExclusion: error removing empty config file: %w", err) - } - return nil - } - - // Otherwise write the updated config - if err := database.exclusionsConfig.Write(configData); err != nil { - return fmt.Errorf("DeleteExclusion: error writing config: %w", err) - } - return nil - } - } - - // Try global exclusion - globalPath := filepath.Join(database.paths["exclusions"], "global.cfg") - if configData, err := database.exclusionsConfig.Parse(globalPath); err == nil { - if _, exists := configData.Sections[sectionID]; exists { - delete(configData.Sections, sectionID) - newOrder := make([]string, 0) - for _, id := range configData.Order { - if id != sectionID { - newOrder = append(newOrder, id) - } - } - configData.Order = newOrder - - // If the global config is empty after deletion, remove the file - if len(configData.Sections) == 0 { - if 
err := os.Remove(globalPath); err != nil { - return fmt.Errorf("DeleteExclusion: error removing empty global config file: %w", err) - } - return nil - } - - // Otherwise write the updated config - if err := database.exclusionsConfig.Write(configData); err != nil { - return fmt.Errorf("DeleteExclusion: error writing global config: %w", err) - } - return nil - } - } - - return fmt.Errorf("DeleteExclusion: exclusion not found for path: %s", path) -} +//go:build linux + +package database + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + configLib "github.com/sonroyaalmerol/pbs-plus/internal/config" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" + "github.com/sonroyaalmerol/pbs-plus/internal/utils/pattern" +) + +func (database *Database) RegisterExclusionPlugin() { + plugin := &configLib.SectionPlugin[types.Exclusion]{ + TypeName: "exclusion", + FolderPath: database.paths["exclusions"], + Validate: func(config types.Exclusion) error { + if !pattern.IsValidPattern(config.Path) { + return fmt.Errorf("invalid exclusion pattern: %s", config.Path) + } + return nil + }, + } + + database.exclusionsConfig = configLib.NewSectionConfig(plugin) +} + +func (database *Database) CreateExclusion(exclusion types.Exclusion) error { + exclusion.Path = strings.ReplaceAll(exclusion.Path, "\\", "/") + + if !pattern.IsValidPattern(exclusion.Path) { + return fmt.Errorf("CreateExclusion: invalid path pattern -> %s", exclusion.Path) + } + + filename := "global" + if exclusion.JobID != "" { + filename = utils.EncodePath(exclusion.JobID) + } + + configPath := filepath.Join(database.paths["exclusions"], filename+".cfg") + + // Read existing exclusions + var configData *configLib.ConfigData[types.Exclusion] + existing, err := database.exclusionsConfig.Parse(configPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("CreateExclusion: error reading existing config: %w", err) + } + + if existing != nil { + configData = existing + } else { + configData = &configLib.ConfigData[types.Exclusion]{ + Sections: make(map[string]*configLib.Section[types.Exclusion]), + Order: make([]string, 0), + FilePath: configPath, + } + } + + // Add new exclusion + sectionID := fmt.Sprintf("excl-%s", exclusion.Path) + configData.Sections[sectionID] = &configLib.Section[types.Exclusion]{ + Type: "exclusion", + ID: sectionID, + Properties: exclusion, + } + configData.Order = append(configData.Order, sectionID) + + if err := database.exclusionsConfig.Write(configData); err != nil { + return fmt.Errorf("CreateExclusion: error writing config: %w", err) + } + + return nil +} + +func (database *Database) GetAllJobExclusions(jobId string) ([]types.Exclusion, error) { + configPath := filepath.Join(database.paths["exclusions"], utils.EncodePath(jobId)+".cfg") + configData, err := database.exclusionsConfig.Parse(configPath) + if err != nil { + if os.IsNotExist(err) { + return []types.Exclusion{}, nil + } + return nil, fmt.Errorf("GetAllJobExclusions: error reading config: %w", err) + } + + var exclusions []types.Exclusion + seenPaths := make(map[string]bool) + + for _, sectionID := range configData.Order { + section := configData.Sections[sectionID] + if seenPaths[section.Properties.Path] { + continue // Skip duplicates + } + seenPaths[section.Properties.Path] = true + + exclusions = append(exclusions, section.Properties) + } + return exclusions, nil +} + +func (database *Database) GetAllGlobalExclusions() ([]types.Exclusion, error) { + configPath := 
filepath.Join(database.paths["exclusions"], "global.cfg") + configData, err := database.exclusionsConfig.Parse(configPath) + if err != nil { + if os.IsNotExist(err) { + return []types.Exclusion{}, nil + } + return nil, fmt.Errorf("GetAllGlobalExclusions: error reading config: %w", err) + } + + var exclusions []types.Exclusion + seenPaths := make(map[string]bool) + + for _, sectionID := range configData.Order { + section := configData.Sections[sectionID] + if seenPaths[section.Properties.Path] { + continue // Skip duplicates + } + seenPaths[section.Properties.Path] = true + + exclusions = append(exclusions, section.Properties) + } + return exclusions, nil +} + +func (database *Database) GetExclusion(path string) (*types.Exclusion, error) { + // Check global exclusions first + globalPath := filepath.Join(database.paths["exclusions"], "global.cfg") + if configData, err := database.exclusionsConfig.Parse(globalPath); err == nil { + sectionID := fmt.Sprintf("excl-%s", path) + if section, exists := configData.Sections[sectionID]; exists { + return &section.Properties, nil + } + } + + // Check job-specific exclusions + files, err := os.ReadDir(database.paths["exclusions"]) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("GetExclusion: error reading directory: %w", err) + } + + for _, file := range files { + if file.Name() == "global" { + continue + } + configPath := filepath.Join(database.paths["exclusions"], file.Name()) + configData, err := database.exclusionsConfig.Parse(configPath) + if err != nil { + continue + } + + sectionID := fmt.Sprintf("excl-%s", path) + if section, exists := configData.Sections[sectionID]; exists { + return &section.Properties, nil + } + } + + return nil, fmt.Errorf("GetExclusion: exclusion not found for path: %s", path) +} + +func (database *Database) UpdateExclusion(exclusion types.Exclusion) error { + exclusion.Path = strings.ReplaceAll(exclusion.Path, "\\", "/") + if !pattern.IsValidPattern(exclusion.Path) { + return fmt.Errorf("UpdateExclusion: invalid path pattern -> %s", exclusion.Path) + } + + configPath := filepath.Join(database.paths["exclusions"], "global.cfg") + if exclusion.JobID != "" { + configPath = filepath.Join(database.paths["exclusions"], utils.EncodePath(exclusion.JobID)+".cfg") + } + + configData, err := database.exclusionsConfig.Parse(configPath) + if err != nil { + return fmt.Errorf("UpdateExclusion: error reading config: %w", err) + } + + sectionID := fmt.Sprintf("excl-%s", exclusion.Path) + _, exists := configData.Sections[sectionID] + if !exists { + return fmt.Errorf("UpdateExclusion: exclusion not found for path: %s", exclusion.Path) + } + + // Update properties + configData.Sections[sectionID].Properties = exclusion + return database.exclusionsConfig.Write(configData) +} + +func (database *Database) DeleteExclusion(path string) error { + path = strings.ReplaceAll(path, "\\", "/") + sectionID := fmt.Sprintf("excl-%s", path) + + // Try job-specific exclusions first + files, err := os.ReadDir(database.paths["exclusions"]) + if err != nil { + return fmt.Errorf("DeleteExclusion: error reading directory: %w", err) + } + + for _, file := range files { + if file.Name() == "global" { + continue + } + configPath := filepath.Join(database.paths["exclusions"], file.Name()) + configData, err := database.exclusionsConfig.Parse(configPath) + if err != nil { + continue + } + + if _, exists := configData.Sections[sectionID]; exists { + delete(configData.Sections, sectionID) + newOrder := make([]string, 0) + for _, id := 
range configData.Order { + if id != sectionID { + newOrder = append(newOrder, id) + } + } + configData.Order = newOrder + + // If the config is empty after deletion, remove the file + if len(configData.Sections) == 0 { + if err := os.Remove(configPath); err != nil { + return fmt.Errorf("DeleteExclusion: error removing empty config file: %w", err) + } + return nil + } + + // Otherwise write the updated config + if err := database.exclusionsConfig.Write(configData); err != nil { + return fmt.Errorf("DeleteExclusion: error writing config: %w", err) + } + return nil + } + } + + // Try global exclusion + globalPath := filepath.Join(database.paths["exclusions"], "global.cfg") + if configData, err := database.exclusionsConfig.Parse(globalPath); err == nil { + if _, exists := configData.Sections[sectionID]; exists { + delete(configData.Sections, sectionID) + newOrder := make([]string, 0) + for _, id := range configData.Order { + if id != sectionID { + newOrder = append(newOrder, id) + } + } + configData.Order = newOrder + + // If the global config is empty after deletion, remove the file + if len(configData.Sections) == 0 { + if err := os.Remove(globalPath); err != nil { + return fmt.Errorf("DeleteExclusion: error removing empty global config file: %w", err) + } + return nil + } + + // Otherwise write the updated config + if err := database.exclusionsConfig.Write(configData); err != nil { + return fmt.Errorf("DeleteExclusion: error writing global config: %w", err) + } + return nil + } + } + + return fmt.Errorf("DeleteExclusion: exclusion not found for path: %s", path) +} diff --git a/internal/store/database/jobs.go b/internal/store/database/jobs.go index d3fb406..a8d57b2 100644 --- a/internal/store/database/jobs.go +++ b/internal/store/database/jobs.go @@ -1,239 +1,239 @@ -//go:build linux - -package database - -import ( - "fmt" - "log" - "os" - "path/filepath" - "strings" - "time" - - configLib "github.com/sonroyaalmerol/pbs-plus/internal/config" - "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" - "github.com/sonroyaalmerol/pbs-plus/internal/store/system" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" -) - -func (database *Database) RegisterJobPlugin() { - plugin := &configLib.SectionPlugin[types.Job]{ - TypeName: "job", - FolderPath: database.paths["jobs"], - Validate: func(config types.Job) error { - if !utils.IsValidNamespace(config.Namespace) && config.Namespace != "" { - return fmt.Errorf("invalid namespace string: %s", config.Namespace) - } - if err := utils.ValidateOnCalendar(config.Schedule); err != nil && config.Schedule != "" { - return fmt.Errorf("invalid schedule string: %s", config.Schedule) - } - if !utils.IsValidPathString(config.Subpath) { - return fmt.Errorf("invalid subpath string: %s", config.Subpath) - } - return nil - }, - } - - database.jobsConfig = configLib.NewSectionConfig(plugin) -} - -func (database *Database) CreateJob(job types.Job) error { - if !utils.IsValidID(job.ID) && job.ID != "" { - return fmt.Errorf("CreateJob: invalid id string -> %s", job.ID) - } - - // Convert job to config format - configData := &configLib.ConfigData[types.Job]{ - Sections: map[string]*configLib.Section[types.Job]{ - job.ID: { - Type: "job", - ID: job.ID, - Properties: types.Job{ - Store: job.Store, - Target: job.Target, - Subpath: job.Subpath, - Schedule: job.Schedule, - Comment: job.Comment, - NotificationMode: job.NotificationMode, - Namespace: 
job.Namespace, - LastRunUpid: job.LastRunUpid, - }, - }, - }, - Order: []string{job.ID}, - } - - if err := database.jobsConfig.Write(configData); err != nil { - return fmt.Errorf("CreateJob: error writing config: %w", err) - } - - // Handle exclusions - if len(job.Exclusions) > 0 { - for _, exclusion := range job.Exclusions { - err := database.CreateExclusion(exclusion) - if err != nil { - continue - } - } - } - - if err := system.SetSchedule(job); err != nil { - syslog.L.Errorf("CreateJob: error setting schedule: %v", err) - } - - return nil -} - -func (database *Database) GetJob(id string) (*types.Job, error) { - jobPath := filepath.Join(database.paths["jobs"], utils.EncodePath(id)+".cfg") - configData, err := database.jobsConfig.Parse(jobPath) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - return nil, fmt.Errorf("GetJob: error reading config: %w", err) - } - - section, exists := configData.Sections[id] - if !exists { - return nil, fmt.Errorf("GetJob: section %s does not exist", id) - } - - lastRunUpid := section.Properties.LastRunUpid - - // Convert config to Job struct - job := §ion.Properties - job.ID = id - job.LastRunUpid = lastRunUpid - - // Get exclusions - exclusions, err := database.GetAllJobExclusions(id) - if err == nil && exclusions != nil { - job.Exclusions = exclusions - pathSlice := []string{} - for _, exclusion := range exclusions { - pathSlice = append(pathSlice, exclusion.Path) - } - job.RawExclusions = strings.Join(pathSlice, "\n") - } - - // Get global exclusions - globalExclusions, err := database.GetAllGlobalExclusions() - if err == nil && globalExclusions != nil { - job.Exclusions = append(job.Exclusions, globalExclusions...) - } - - // Update dynamic fields - if job.LastRunUpid != "" { - task, err := proxmox.Session.GetTaskByUPID(job.LastRunUpid) - if err != nil { - log.Printf("GetJob: error getting task by UPID -> %v\n", err) - } else { - job.LastRunEndtime = &task.EndTime - if task.Status == "stopped" { - job.LastRunState = &task.ExitStatus - tmpDuration := task.EndTime - task.StartTime - job.Duration = &tmpDuration - } else { - tmpDuration := time.Now().Unix() - task.StartTime - job.Duration = &tmpDuration - } - } - } - - // Get next schedule - nextSchedule, err := system.GetNextSchedule(job) - if err == nil && nextSchedule != nil { - nextSchedUnix := nextSchedule.Unix() - job.NextRun = &nextSchedUnix - } - - return job, nil -} - -func (database *Database) UpdateJob(job types.Job) error { - if !utils.IsValidID(job.ID) && job.ID != "" { - return fmt.Errorf("UpdateJob: invalid id string -> %s", job.ID) - } - - // Convert job to config format - configData := &configLib.ConfigData[types.Job]{ - Sections: map[string]*configLib.Section[types.Job]{ - job.ID: { - Type: "job", - ID: job.ID, - Properties: job, - }, - }, - Order: []string{job.ID}, - } - - if err := database.jobsConfig.Write(configData); err != nil { - return fmt.Errorf("UpdateJob: error writing config: %w", err) - } - - // Update exclusions - exclusionPath := filepath.Join(database.paths["exclusions"], job.ID+".cfg") - if err := os.RemoveAll(exclusionPath); err != nil { - return fmt.Errorf("UpdateJob: error removing old exclusions: %w", err) - } - - if len(job.Exclusions) > 0 { - for _, exclusion := range job.Exclusions { - if exclusion.JobID != job.ID { - continue - } - err := database.CreateExclusion(exclusion) - if err != nil { - syslog.L.Errorf("UpdateJob: error creating job exclusion: %v", err) - continue - } - } - } - - if err := system.SetSchedule(job); err != nil { - 
syslog.L.Errorf("UpdateJob: error setting schedule: %v", err) - } - - return nil -} - -func (database *Database) GetAllJobs() ([]types.Job, error) { - files, err := os.ReadDir(database.paths["jobs"]) - if err != nil { - return nil, fmt.Errorf("GetAllJobs: error reading jobs directory: %w", err) - } - - var jobs []types.Job - for _, file := range files { - if file.IsDir() { - continue - } - - job, err := database.GetJob(utils.DecodePath(strings.TrimSuffix(file.Name(), ".cfg"))) - if err != nil || job == nil { - syslog.L.Errorf("GetAllJobs: error getting job: %v", err) - continue - } - jobs = append(jobs, *job) - } - - return jobs, nil -} - -func (database *Database) DeleteJob(id string) error { - jobPath := filepath.Join(database.paths["jobs"], utils.EncodePath(id)+".cfg") - if err := os.Remove(jobPath); err != nil { - if !os.IsNotExist(err) { - return fmt.Errorf("DeleteJob: error deleting job file: %w", err) - } - } - - if err := system.DeleteSchedule(id); err != nil { - syslog.L.Errorf("DeleteJob: error deleting schedule: %v", err) - } - - return nil -} +//go:build linux + +package database + +import ( + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + configLib "github.com/sonroyaalmerol/pbs-plus/internal/config" + "github.com/sonroyaalmerol/pbs-plus/internal/store/proxmox" + "github.com/sonroyaalmerol/pbs-plus/internal/store/system" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" + "github.com/sonroyaalmerol/pbs-plus/internal/utils" +) + +func (database *Database) RegisterJobPlugin() { + plugin := &configLib.SectionPlugin[types.Job]{ + TypeName: "job", + FolderPath: database.paths["jobs"], + Validate: func(config types.Job) error { + if !utils.IsValidNamespace(config.Namespace) && config.Namespace != "" { + return fmt.Errorf("invalid namespace string: %s", config.Namespace) + } + if err := utils.ValidateOnCalendar(config.Schedule); err != nil && config.Schedule != "" { + return fmt.Errorf("invalid schedule string: %s", config.Schedule) + } + if !utils.IsValidPathString(config.Subpath) { + return fmt.Errorf("invalid subpath string: %s", config.Subpath) + } + return nil + }, + } + + database.jobsConfig = configLib.NewSectionConfig(plugin) +} + +func (database *Database) CreateJob(job types.Job) error { + if !utils.IsValidID(job.ID) && job.ID != "" { + return fmt.Errorf("CreateJob: invalid id string -> %s", job.ID) + } + + // Convert job to config format + configData := &configLib.ConfigData[types.Job]{ + Sections: map[string]*configLib.Section[types.Job]{ + job.ID: { + Type: "job", + ID: job.ID, + Properties: types.Job{ + Store: job.Store, + Target: job.Target, + Subpath: job.Subpath, + Schedule: job.Schedule, + Comment: job.Comment, + NotificationMode: job.NotificationMode, + Namespace: job.Namespace, + LastRunUpid: job.LastRunUpid, + }, + }, + }, + Order: []string{job.ID}, + } + + if err := database.jobsConfig.Write(configData); err != nil { + return fmt.Errorf("CreateJob: error writing config: %w", err) + } + + // Handle exclusions + if len(job.Exclusions) > 0 { + for _, exclusion := range job.Exclusions { + err := database.CreateExclusion(exclusion) + if err != nil { + continue + } + } + } + + if err := system.SetSchedule(job); err != nil { + syslog.L.Errorf("CreateJob: error setting schedule: %v", err) + } + + return nil +} + +func (database *Database) GetJob(id string) (*types.Job, error) { + jobPath := filepath.Join(database.paths["jobs"], utils.EncodePath(id)+".cfg") + configData, err := 
database.jobsConfig.Parse(jobPath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("GetJob: error reading config: %w", err) + } + + section, exists := configData.Sections[id] + if !exists { + return nil, fmt.Errorf("GetJob: section %s does not exist", id) + } + + lastRunUpid := section.Properties.LastRunUpid + + // Convert config to Job struct + job := §ion.Properties + job.ID = id + job.LastRunUpid = lastRunUpid + + // Get exclusions + exclusions, err := database.GetAllJobExclusions(id) + if err == nil && exclusions != nil { + job.Exclusions = exclusions + pathSlice := []string{} + for _, exclusion := range exclusions { + pathSlice = append(pathSlice, exclusion.Path) + } + job.RawExclusions = strings.Join(pathSlice, "\n") + } + + // Get global exclusions + globalExclusions, err := database.GetAllGlobalExclusions() + if err == nil && globalExclusions != nil { + job.Exclusions = append(job.Exclusions, globalExclusions...) + } + + // Update dynamic fields + if job.LastRunUpid != "" { + task, err := proxmox.Session.GetTaskByUPID(job.LastRunUpid) + if err != nil { + log.Printf("GetJob: error getting task by UPID -> %v\n", err) + } else { + job.LastRunEndtime = &task.EndTime + if task.Status == "stopped" { + job.LastRunState = &task.ExitStatus + tmpDuration := task.EndTime - task.StartTime + job.Duration = &tmpDuration + } else { + tmpDuration := time.Now().Unix() - task.StartTime + job.Duration = &tmpDuration + } + } + } + + // Get next schedule + nextSchedule, err := system.GetNextSchedule(job) + if err == nil && nextSchedule != nil { + nextSchedUnix := nextSchedule.Unix() + job.NextRun = &nextSchedUnix + } + + return job, nil +} + +func (database *Database) UpdateJob(job types.Job) error { + if !utils.IsValidID(job.ID) && job.ID != "" { + return fmt.Errorf("UpdateJob: invalid id string -> %s", job.ID) + } + + // Convert job to config format + configData := &configLib.ConfigData[types.Job]{ + Sections: map[string]*configLib.Section[types.Job]{ + job.ID: { + Type: "job", + ID: job.ID, + Properties: job, + }, + }, + Order: []string{job.ID}, + } + + if err := database.jobsConfig.Write(configData); err != nil { + return fmt.Errorf("UpdateJob: error writing config: %w", err) + } + + // Update exclusions + exclusionPath := filepath.Join(database.paths["exclusions"], job.ID+".cfg") + if err := os.RemoveAll(exclusionPath); err != nil { + return fmt.Errorf("UpdateJob: error removing old exclusions: %w", err) + } + + if len(job.Exclusions) > 0 { + for _, exclusion := range job.Exclusions { + if exclusion.JobID != job.ID { + continue + } + err := database.CreateExclusion(exclusion) + if err != nil { + syslog.L.Errorf("UpdateJob: error creating job exclusion: %v", err) + continue + } + } + } + + if err := system.SetSchedule(job); err != nil { + syslog.L.Errorf("UpdateJob: error setting schedule: %v", err) + } + + return nil +} + +func (database *Database) GetAllJobs() ([]types.Job, error) { + files, err := os.ReadDir(database.paths["jobs"]) + if err != nil { + return nil, fmt.Errorf("GetAllJobs: error reading jobs directory: %w", err) + } + + var jobs []types.Job + for _, file := range files { + if file.IsDir() { + continue + } + + job, err := database.GetJob(utils.DecodePath(strings.TrimSuffix(file.Name(), ".cfg"))) + if err != nil || job == nil { + syslog.L.Errorf("GetAllJobs: error getting job: %v", err) + continue + } + jobs = append(jobs, *job) + } + + return jobs, nil +} + +func (database *Database) DeleteJob(id string) error { + jobPath := 
filepath.Join(database.paths["jobs"], utils.EncodePath(id)+".cfg") + if err := os.Remove(jobPath); err != nil { + if !os.IsNotExist(err) { + return fmt.Errorf("DeleteJob: error deleting job file: %w", err) + } + } + + if err := system.DeleteSchedule(id); err != nil { + syslog.L.Errorf("DeleteJob: error deleting schedule: %v", err) + } + + return nil +} diff --git a/internal/store/proxmox/http.go b/internal/store/proxmox/http.go index ddb5476..e0d9889 100644 --- a/internal/store/proxmox/http.go +++ b/internal/store/proxmox/http.go @@ -25,7 +25,7 @@ type ProxmoxSession struct { func InitializeProxmox() { Session = &ProxmoxSession{ HTTPClient: &http.Client{ - Timeout: time.Second * 30, + Timeout: time.Minute * 2, Transport: utils.BaseTransport, }, } diff --git a/internal/store/proxmox/refresh_token.go b/internal/store/proxmox/refresh_token.go index aeaf5ad..84da01b 100644 --- a/internal/store/proxmox/refresh_token.go +++ b/internal/store/proxmox/refresh_token.go @@ -1,147 +1,147 @@ -//go:build linux - -package proxmox - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strings" - - "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" -) - -type Token struct { - CSRFToken string `json:"CSRFPreventionToken"` - Ticket string `json:"ticket"` - Username string `json:"username"` -} - -type TokenResponse struct { - Data Token `json:"data"` -} - -type TokenRequest struct { - Username string `json:"username"` - Password string `json:"password"` -} - -type APITokenRequest struct { - Comment string `json:"comment"` -} - -type APITokenResponse struct { - Data APIToken `json:"data"` -} - -type ACLRequest struct { - Path string `json:"path"` - Role string `json:"role"` - AuthId string `json:"auth-id"` -} - -type APIToken struct { - TokenId string `json:"tokenid"` - Value string `json:"value"` -} - -func (proxmoxSess *ProxmoxSession) CreateAPIToken() (*APIToken, error) { - if proxmoxSess.LastToken == nil { - return nil, fmt.Errorf("CreateAPIToken: token required") - } - - _ = proxmoxSess.ProxmoxHTTPRequest( - http.MethodDelete, - fmt.Sprintf("/api2/json/access/users/%s/token/pbs-plus-auth", proxmoxSess.LastToken.Username), - nil, - nil, - ) - - reqBody, err := json.Marshal(&APITokenRequest{ - Comment: "Autogenerated API token for PBS Plus", - }) - if err != nil { - return nil, fmt.Errorf("CreateAPIToken: error creating req body -> %w", err) - } - - var tokenResp APITokenResponse - err = proxmoxSess.ProxmoxHTTPRequest( - http.MethodPost, - fmt.Sprintf("/api2/json/access/users/%s/token/pbs-plus-auth", proxmoxSess.LastToken.Username), - bytes.NewBuffer(reqBody), - &tokenResp, - ) - if err != nil { - if !strings.Contains(err.Error(), "already exists") { - return nil, fmt.Errorf("CreateAPIToken: error executing http request token post -> %w", err) - } - } - - aclBody, err := json.Marshal(&ACLRequest{ - AuthId: tokenResp.Data.TokenId, - Role: "Admin", - Path: "/", - }) - if err != nil { - return nil, fmt.Errorf("CreateAPIToken: error creating acl body -> %w", err) - } - - err = proxmoxSess.ProxmoxHTTPRequest( - http.MethodPut, - "/api2/json/access/acl", - bytes.NewBuffer(aclBody), - nil, - ) - if err != nil { - if !strings.Contains(err.Error(), "already exists") { - return nil, fmt.Errorf("CreateAPIToken: error executing http request acl put -> %w", err) - } - } - - return &tokenResp.Data, nil -} - -func (token *APIToken) SaveToFile() error { - if token == nil { - return nil - } - - tokenFileContent, _ := json.Marshal(token) - file, err := 
os.OpenFile(filepath.Join(constants.DbBasePath, "pbs-plus-token.json"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return err - } - defer file.Close() - - _, err = file.WriteString(string(tokenFileContent)) - if err != nil { - return err - } - - return nil -} - -func GetAPITokenFromFile() (*APIToken, error) { - jsonFile, err := os.Open(filepath.Join(constants.DbBasePath, "pbs-plus-token.json")) - if err != nil { - return nil, err - } - defer jsonFile.Close() - - byteValue, err := io.ReadAll(jsonFile) - if err != nil { - return nil, err - } - - var result APIToken - err = json.Unmarshal([]byte(byteValue), &result) - if err != nil { - return nil, err - } - - return &result, nil -} +//go:build linux + +package proxmox + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/sonroyaalmerol/pbs-plus/internal/store/constants" +) + +type Token struct { + CSRFToken string `json:"CSRFPreventionToken"` + Ticket string `json:"ticket"` + Username string `json:"username"` +} + +type TokenResponse struct { + Data Token `json:"data"` +} + +type TokenRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type APITokenRequest struct { + Comment string `json:"comment"` +} + +type APITokenResponse struct { + Data APIToken `json:"data"` +} + +type ACLRequest struct { + Path string `json:"path"` + Role string `json:"role"` + AuthId string `json:"auth-id"` +} + +type APIToken struct { + TokenId string `json:"tokenid"` + Value string `json:"value"` +} + +func (proxmoxSess *ProxmoxSession) CreateAPIToken() (*APIToken, error) { + if proxmoxSess.LastToken == nil { + return nil, fmt.Errorf("CreateAPIToken: token required") + } + + _ = proxmoxSess.ProxmoxHTTPRequest( + http.MethodDelete, + fmt.Sprintf("/api2/json/access/users/%s/token/pbs-plus-auth", proxmoxSess.LastToken.Username), + nil, + nil, + ) + + reqBody, err := json.Marshal(&APITokenRequest{ + Comment: "Autogenerated API token for PBS Plus", + }) + if err != nil { + return nil, fmt.Errorf("CreateAPIToken: error creating req body -> %w", err) + } + + var tokenResp APITokenResponse + err = proxmoxSess.ProxmoxHTTPRequest( + http.MethodPost, + fmt.Sprintf("/api2/json/access/users/%s/token/pbs-plus-auth", proxmoxSess.LastToken.Username), + bytes.NewBuffer(reqBody), + &tokenResp, + ) + if err != nil { + if !strings.Contains(err.Error(), "already exists") { + return nil, fmt.Errorf("CreateAPIToken: error executing http request token post -> %w", err) + } + } + + aclBody, err := json.Marshal(&ACLRequest{ + AuthId: tokenResp.Data.TokenId, + Role: "Admin", + Path: "/", + }) + if err != nil { + return nil, fmt.Errorf("CreateAPIToken: error creating acl body -> %w", err) + } + + err = proxmoxSess.ProxmoxHTTPRequest( + http.MethodPut, + "/api2/json/access/acl", + bytes.NewBuffer(aclBody), + nil, + ) + if err != nil { + if !strings.Contains(err.Error(), "already exists") { + return nil, fmt.Errorf("CreateAPIToken: error executing http request acl put -> %w", err) + } + } + + return &tokenResp.Data, nil +} + +func (token *APIToken) SaveToFile() error { + if token == nil { + return nil + } + + tokenFileContent, _ := json.Marshal(token) + file, err := os.OpenFile(filepath.Join(constants.DbBasePath, "pbs-plus-token.json"), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer file.Close() + + _, err = file.WriteString(string(tokenFileContent)) + if err != nil { + return err + } + + return nil +} + +func GetAPITokenFromFile() 
(*APIToken, error) { + jsonFile, err := os.Open(filepath.Join(constants.DbBasePath, "pbs-plus-token.json")) + if err != nil { + return nil, err + } + defer jsonFile.Close() + + byteValue, err := io.ReadAll(jsonFile) + if err != nil { + return nil, err + } + + var result APIToken + err = json.Unmarshal([]byte(byteValue), &result) + if err != nil { + return nil, err + } + + return &result, nil +} diff --git a/internal/store/proxmox/tasks.go b/internal/store/proxmox/tasks.go index 62df42c..1984aea 100644 --- a/internal/store/proxmox/tasks.go +++ b/internal/store/proxmox/tasks.go @@ -1,224 +1,224 @@ -//go:build linux - -package proxmox - -import ( - "context" - "fmt" - "log" - "net/http" - "os" - "path/filepath" - "strings" - - "github.com/fsnotify/fsnotify" - "github.com/sonroyaalmerol/pbs-plus/internal/store/types" -) - -type TaskResponse struct { - Data Task `json:"data"` - Total int `json:"total"` -} - -type Task struct { - WID string `json:"id"` - Node string `json:"node"` - PID int `json:"pid"` - PStart int `json:"pstart"` - StartTime int64 `json:"starttime"` - EndTime int64 `json:"endtime"` - UPID string `json:"upid"` - User string `json:"user"` - WorkerType string `json:"worker_type"` - Status string `json:"status"` - ExitStatus string `json:"exitstatus"` -} - -func isDir(path string) bool { - info, err := os.Stat(path) - if err != nil { - log.Println("Error checking path:", err) - return false - } - return info.IsDir() -} - -func encodeToHexEscapes(input string) string { - var encoded strings.Builder - for _, char := range input { - if char >= 'a' && char <= 'z' || char >= 'A' && char <= 'Z' || char >= '0' && char <= '9' { - encoded.WriteRune(char) - } else { - encoded.WriteString(fmt.Sprintf(`\x%02x`, char)) - } - } - - return encoded.String() -} - -func (proxmoxSess *ProxmoxSession) GetJobTask(ctx context.Context, readyChan chan struct{}, job *types.Job, target *types.Target) (*Task, error) { - tasksParentPath := "/var/log/proxmox-backup/tasks" - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, fmt.Errorf("failed to create watcher: %w", err) - } - err = watcher.Add(tasksParentPath) - if err != nil { - return nil, fmt.Errorf("failed to add folder to watcher: %w", err) - } - - // Helper function to check if a file matches our search criteria - checkFile := func(filePath string, searchString string) (*Task, error) { - if !strings.Contains(filePath, ".tmp_") && strings.Contains(filePath, searchString) { - log.Printf("Proceeding: %s contains %s\n", filePath, searchString) - fileName := filepath.Base(filePath) - log.Printf("Getting UPID: %s\n", fileName) - newTask, err := proxmoxSess.GetTaskByUPID(fileName) - if err != nil { - return nil, fmt.Errorf("GetJobTask: error getting task: %v\n", err) - } - log.Printf("Sending UPID: %s\n", fileName) - return newTask, nil - } - return nil, nil - } - - // Helper function to scan directory for matching files - scanDirectory := func(dirPath string, searchString string) (*Task, error) { - files, err := os.ReadDir(dirPath) - if err != nil { - log.Printf("Error reading directory %s: %v\n", dirPath, err) - return nil, nil - } - - for _, file := range files { - if !file.IsDir() { - filePath := filepath.Join(dirPath, file.Name()) - task, err := checkFile(filePath, searchString) - if err != nil { - return nil, err - } - if task != nil { - return task, nil - } - } - } - return nil, nil - } - - err = filepath.Walk(tasksParentPath, func(path string, info os.FileInfo, err error) error { - if err != nil { - log.Println("Error walking the 
path:", err) - return err - } - if info.IsDir() { - err = watcher.Add(path) - if err != nil { - log.Println("Failed to add directory to watcher:", err) - } - } - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to walk folder: %w", err) - } - - hostname, err := os.Hostname() - if err != nil { - hostnameFile, err := os.ReadFile("/etc/hostname") - if err != nil { - hostname = "localhost" - } - hostname = strings.TrimSpace(string(hostnameFile)) - } - - isAgent := strings.HasPrefix(target.Path, "agent://") - backupId := hostname - if isAgent { - backupId = strings.TrimSpace(strings.Split(target.Name, " - ")[0]) - } - - searchString := fmt.Sprintf(":backup:%s%shost-%s", job.Store, encodeToHexEscapes(":"), encodeToHexEscapes(backupId)) - - close(readyChan) - defer watcher.Close() - - for { - select { - case event := <-watcher.Events: - if event.Op&fsnotify.Create == fsnotify.Create { - if isDir(event.Name) { - err = watcher.Add(event.Name) - if err != nil { - log.Println("Failed to add directory to watcher:", err) - } - - task, err := scanDirectory(event.Name, searchString) - if err != nil { - return nil, err - } - if task != nil { - return task, nil - } - } else { - task, err := checkFile(event.Name, searchString) - if err != nil { - return nil, err - } - if task != nil { - return task, nil - } - } - } - case <-ctx.Done(): - return nil, nil - } - } -} - -func (proxmoxSess *ProxmoxSession) GetTaskByUPID(upid string) (*Task, error) { - var resp TaskResponse - err := proxmoxSess.ProxmoxHTTPRequest( - http.MethodGet, - fmt.Sprintf("/api2/json/nodes/localhost/tasks/%s/status", upid), - nil, - &resp, - ) - if err != nil { - return nil, fmt.Errorf("GetTaskByUPID: error creating http request -> %w", err) - } - - if resp.Data.Status == "stopped" { - endTime, err := proxmoxSess.GetTaskEndTime(&resp.Data) - if err != nil { - return nil, fmt.Errorf("GetTaskByUPID: error getting task end time -> %w", err) - } - - resp.Data.EndTime = endTime - } - - return &resp.Data, nil -} - -func (proxmoxSess *ProxmoxSession) GetTaskEndTime(task *Task) (int64, error) { - if proxmoxSess.LastToken == nil && proxmoxSess.APIToken == nil { - return -1, fmt.Errorf("GetTaskEndTime: token is required") - } - - upidSplit := strings.Split(task.UPID, ":") - if len(upidSplit) < 4 { - return -1, fmt.Errorf("GetTaskEndTime: error getting tasks: invalid upid") - } - - parsed := upidSplit[3] - logFolder := parsed[len(parsed)-2:] - - logPath := fmt.Sprintf("/var/log/proxmox-backup/tasks/%s/%s", logFolder, task.UPID) - - logStat, err := os.Stat(logPath) - if err == nil { - return logStat.ModTime().Unix(), nil - } - - return -1, fmt.Errorf("GetTaskEndTime: error getting tasks: not found (%s) -> %w", logPath, err) -} +//go:build linux + +package proxmox + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/fsnotify/fsnotify" + "github.com/sonroyaalmerol/pbs-plus/internal/store/types" +) + +type TaskResponse struct { + Data Task `json:"data"` + Total int `json:"total"` +} + +type Task struct { + WID string `json:"id"` + Node string `json:"node"` + PID int `json:"pid"` + PStart int `json:"pstart"` + StartTime int64 `json:"starttime"` + EndTime int64 `json:"endtime"` + UPID string `json:"upid"` + User string `json:"user"` + WorkerType string `json:"worker_type"` + Status string `json:"status"` + ExitStatus string `json:"exitstatus"` +} + +func isDir(path string) bool { + info, err := os.Stat(path) + if err != nil { + log.Println("Error checking path:", err) + return false + } 
+ return info.IsDir() +} + +func encodeToHexEscapes(input string) string { + var encoded strings.Builder + for _, char := range input { + if char >= 'a' && char <= 'z' || char >= 'A' && char <= 'Z' || char >= '0' && char <= '9' { + encoded.WriteRune(char) + } else { + encoded.WriteString(fmt.Sprintf(`\x%02x`, char)) + } + } + + return encoded.String() +} + +func (proxmoxSess *ProxmoxSession) GetJobTask(ctx context.Context, readyChan chan struct{}, job *types.Job, target *types.Target) (*Task, error) { + tasksParentPath := "/var/log/proxmox-backup/tasks" + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to create watcher: %w", err) + } + err = watcher.Add(tasksParentPath) + if err != nil { + return nil, fmt.Errorf("failed to add folder to watcher: %w", err) + } + + // Helper function to check if a file matches our search criteria + checkFile := func(filePath string, searchString string) (*Task, error) { + if !strings.Contains(filePath, ".tmp_") && strings.Contains(filePath, searchString) { + log.Printf("Proceeding: %s contains %s\n", filePath, searchString) + fileName := filepath.Base(filePath) + log.Printf("Getting UPID: %s\n", fileName) + newTask, err := proxmoxSess.GetTaskByUPID(fileName) + if err != nil { + return nil, fmt.Errorf("GetJobTask: error getting task: %v\n", err) + } + log.Printf("Sending UPID: %s\n", fileName) + return newTask, nil + } + return nil, nil + } + + // Helper function to scan directory for matching files + scanDirectory := func(dirPath string, searchString string) (*Task, error) { + files, err := os.ReadDir(dirPath) + if err != nil { + log.Printf("Error reading directory %s: %v\n", dirPath, err) + return nil, nil + } + + for _, file := range files { + if !file.IsDir() { + filePath := filepath.Join(dirPath, file.Name()) + task, err := checkFile(filePath, searchString) + if err != nil { + return nil, err + } + if task != nil { + return task, nil + } + } + } + return nil, nil + } + + err = filepath.Walk(tasksParentPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Println("Error walking the path:", err) + return err + } + if info.IsDir() { + err = watcher.Add(path) + if err != nil { + log.Println("Failed to add directory to watcher:", err) + } + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to walk folder: %w", err) + } + + hostname, err := os.Hostname() + if err != nil { + hostnameFile, err := os.ReadFile("/etc/hostname") + if err != nil { + hostname = "localhost" + } + hostname = strings.TrimSpace(string(hostnameFile)) + } + + isAgent := strings.HasPrefix(target.Path, "agent://") + backupId := hostname + if isAgent { + backupId = strings.TrimSpace(strings.Split(target.Name, " - ")[0]) + } + + searchString := fmt.Sprintf(":backup:%s%shost-%s", job.Store, encodeToHexEscapes(":"), encodeToHexEscapes(backupId)) + + close(readyChan) + defer watcher.Close() + + for { + select { + case event := <-watcher.Events: + if event.Op&fsnotify.Create == fsnotify.Create { + if isDir(event.Name) { + err = watcher.Add(event.Name) + if err != nil { + log.Println("Failed to add directory to watcher:", err) + } + + task, err := scanDirectory(event.Name, searchString) + if err != nil { + return nil, err + } + if task != nil { + return task, nil + } + } else { + task, err := checkFile(event.Name, searchString) + if err != nil { + return nil, err + } + if task != nil { + return task, nil + } + } + } + case <-ctx.Done(): + return nil, nil + } + } +} + +func (proxmoxSess *ProxmoxSession) 
GetTaskByUPID(upid string) (*Task, error) { + var resp TaskResponse + err := proxmoxSess.ProxmoxHTTPRequest( + http.MethodGet, + fmt.Sprintf("/api2/json/nodes/localhost/tasks/%s/status", upid), + nil, + &resp, + ) + if err != nil { + return nil, fmt.Errorf("GetTaskByUPID: error creating http request -> %w", err) + } + + if resp.Data.Status == "stopped" { + endTime, err := proxmoxSess.GetTaskEndTime(&resp.Data) + if err != nil { + return nil, fmt.Errorf("GetTaskByUPID: error getting task end time -> %w", err) + } + + resp.Data.EndTime = endTime + } + + return &resp.Data, nil +} + +func (proxmoxSess *ProxmoxSession) GetTaskEndTime(task *Task) (int64, error) { + if proxmoxSess.LastToken == nil && proxmoxSess.APIToken == nil { + return -1, fmt.Errorf("GetTaskEndTime: token is required") + } + + upidSplit := strings.Split(task.UPID, ":") + if len(upidSplit) < 4 { + return -1, fmt.Errorf("GetTaskEndTime: error getting tasks: invalid upid") + } + + parsed := upidSplit[3] + logFolder := parsed[len(parsed)-2:] + + logPath := fmt.Sprintf("/var/log/proxmox-backup/tasks/%s/%s", logFolder, task.UPID) + + logStat, err := os.Stat(logPath) + if err == nil { + return logStat.ModTime().Unix(), nil + } + + return -1, fmt.Errorf("GetTaskEndTime: error getting tasks: not found (%s) -> %w", logPath, err) +} diff --git a/internal/store/store.go b/internal/store/store.go index adbfbcb..e1f02eb 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -1,32 +1,32 @@ -//go:build linux - -package store - -import ( - "fmt" - - "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" - "github.com/sonroyaalmerol/pbs-plus/internal/store/database" - "github.com/sonroyaalmerol/pbs-plus/internal/websockets" -) - -// Store holds the configuration system -type Store struct { - WSHub *websockets.Server - CertGenerator *certificates.Generator - Database *database.Database -} - -func Initialize(wsHub *websockets.Server, paths map[string]string) (*Store, error) { - database, err := database.Initialize(paths) - if err != nil { - return nil, fmt.Errorf("Initialize: error initializing database -> %w", err) - } - - store := &Store{ - WSHub: wsHub, - Database: database, - } - - return store, nil -} +//go:build linux + +package store + +import ( + "fmt" + + "github.com/sonroyaalmerol/pbs-plus/internal/auth/certificates" + "github.com/sonroyaalmerol/pbs-plus/internal/store/database" + "github.com/sonroyaalmerol/pbs-plus/internal/websockets" +) + +// Store holds the configuration system +type Store struct { + WSHub *websockets.Server + CertGenerator *certificates.Generator + Database *database.Database +} + +func Initialize(wsHub *websockets.Server, paths map[string]string) (*Store, error) { + database, err := database.Initialize(paths) + if err != nil { + return nil, fmt.Errorf("Initialize: error initializing database -> %w", err) + } + + store := &Store{ + WSHub: wsHub, + Database: database, + } + + return store, nil +} diff --git a/internal/utils/backup_helpers.go b/internal/utils/backup_helpers.go index 122df64..b9287f5 100644 --- a/internal/utils/backup_helpers.go +++ b/internal/utils/backup_helpers.go @@ -1,82 +1,82 @@ -package utils - -import ( - "bufio" - "fmt" - "os" - "strings" - "time" - - "github.com/fsnotify/fsnotify" -) - -func WaitForLogFile(taskUpid string, maxWait time.Duration) error { - // Path to the active tasks - logPath := "/var/log/proxmox-backup/tasks/active" - - if _, found := checkForLine(logPath, taskUpid); found { - return nil - } - - // Create new watcher - watcher, err := 
fsnotify.NewWatcher() - if err != nil { - return fmt.Errorf("error creating watcher: %w", err) - } - defer watcher.Close() - - // Start watching the file - err = watcher.Add(logPath) - if err != nil { - return fmt.Errorf("error watching file %s: %w", logPath, err) - } - - // Create a timeout channel - timeout := time.After(maxWait) - - // First check if the line already exists - if _, found := checkForLine(logPath, taskUpid); found { - return nil - } - - for { - select { - case event, ok := <-watcher.Events: - if !ok { - return fmt.Errorf("watcher channel closed") - } - - if event.Op&fsnotify.Write == fsnotify.Write { - if _, found := checkForLine(logPath, taskUpid); found { - return nil - } - } - - case err, ok := <-watcher.Errors: - if !ok { - return fmt.Errorf("watcher error channel closed") - } - return fmt.Errorf("watcher error: %w", err) - - case <-timeout: - return fmt.Errorf("timeout waiting for log file after %v", maxWait) - } - } -} - -func checkForLine(filePath, taskUpid string) (*os.File, bool) { - file, err := os.Open(filePath) - if err != nil { - return nil, false - } - defer file.Close() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - if strings.Contains(scanner.Text(), taskUpid) { - return file, true - } - } - - return nil, false -} +package utils + +import ( + "bufio" + "fmt" + "os" + "strings" + "time" + + "github.com/fsnotify/fsnotify" +) + +func WaitForLogFile(taskUpid string, maxWait time.Duration) error { + // Path to the active tasks + logPath := "/var/log/proxmox-backup/tasks/active" + + if _, found := checkForLine(logPath, taskUpid); found { + return nil + } + + // Create new watcher + watcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("error creating watcher: %w", err) + } + defer watcher.Close() + + // Start watching the file + err = watcher.Add(logPath) + if err != nil { + return fmt.Errorf("error watching file %s: %w", logPath, err) + } + + // Create a timeout channel + timeout := time.After(maxWait) + + // First check if the line already exists + if _, found := checkForLine(logPath, taskUpid); found { + return nil + } + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return fmt.Errorf("watcher channel closed") + } + + if event.Op&fsnotify.Write == fsnotify.Write { + if _, found := checkForLine(logPath, taskUpid); found { + return nil + } + } + + case err, ok := <-watcher.Errors: + if !ok { + return fmt.Errorf("watcher error channel closed") + } + return fmt.Errorf("watcher error: %w", err) + + case <-timeout: + return fmt.Errorf("timeout waiting for log file after %v", maxWait) + } + } +} + +func checkForLine(filePath, taskUpid string) (*os.File, bool) { + file, err := os.Open(filePath) + if err != nil { + return nil, false + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + if strings.Contains(scanner.Text(), taskUpid) { + return file, true + } + } + + return nil, false +} diff --git a/internal/utils/digest.go b/internal/utils/digest.go index f1a92c2..3fa0fd9 100644 --- a/internal/utils/digest.go +++ b/internal/utils/digest.go @@ -1,27 +1,27 @@ -package utils - -import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" -) - -func CalculateDigest(data any) (string, error) { - jsonData, err := json.Marshal(data) - if err != nil { - return "", fmt.Errorf("CalculateDigest: failed to marshal data to JSON -> %w", err) - } - - if string(jsonData) == "[]" || string(jsonData) == "{}" { - jsonData = []byte{} - } - - hash := sha256.New() - - hash.Write(jsonData) 
- - digest := hash.Sum(nil) - - return hex.EncodeToString(digest), nil -} +package utils + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" +) + +func CalculateDigest(data any) (string, error) { + jsonData, err := json.Marshal(data) + if err != nil { + return "", fmt.Errorf("CalculateDigest: failed to marshal data to JSON -> %w", err) + } + + if string(jsonData) == "[]" || string(jsonData) == "{}" { + jsonData = []byte{} + } + + hash := sha256.New() + + hash.Write(jsonData) + + digest := hash.Sum(nil) + + return hex.EncodeToString(digest), nil +} diff --git a/internal/utils/is_mounted.go b/internal/utils/is_mounted.go index f2a647c..94e536c 100644 --- a/internal/utils/is_mounted.go +++ b/internal/utils/is_mounted.go @@ -1,27 +1,27 @@ -package utils - -import ( - "bufio" - "os" - "strings" -) - -func IsMounted(path string) bool { - // Open /proc/self/mountinfo to check mounts - mountInfoFile, err := os.Open("/proc/self/mountinfo") - if err != nil { - return false - } - defer mountInfoFile.Close() - - scanner := bufio.NewScanner(mountInfoFile) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) >= 5 && fields[4] == path { - return true - } - } - - return false -} +package utils + +import ( + "bufio" + "os" + "strings" +) + +func IsMounted(path string) bool { + // Open /proc/self/mountinfo to check mounts + mountInfoFile, err := os.Open("/proc/self/mountinfo") + if err != nil { + return false + } + defer mountInfoFile.Close() + + scanner := bufio.NewScanner(mountInfoFile) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 5 && fields[4] == path { + return true + } + } + + return false +} diff --git a/internal/utils/local_drives.go b/internal/utils/local_drives.go index 5c52976..3ae511d 100644 --- a/internal/utils/local_drives.go +++ b/internal/utils/local_drives.go @@ -1,137 +1,137 @@ -//go:build windows - -package utils - -import ( - "fmt" - - "golang.org/x/sys/windows" -) - -// DriveInfo contains detailed information about a drive - -// getDriveTypeString returns a human-readable string describing the type of drive -func getDriveTypeString(dt uint32) string { - switch dt { - case windows.DRIVE_UNKNOWN: - return "Unknown" - case windows.DRIVE_NO_ROOT_DIR: - return "No Root Directory" - case windows.DRIVE_REMOVABLE: - return "Removable" - case windows.DRIVE_FIXED: - return "Fixed" - case windows.DRIVE_REMOTE: - return "Network" - case windows.DRIVE_CDROM: - return "CD-ROM" - case windows.DRIVE_RAMDISK: - return "RAM Disk" - default: - return "Unknown" - } -} - -// humanizeBytes converts a byte count into a human-readable string with appropriate units (KB, MB, GB, TB) -func humanizeBytes(bytes uint64) string { - const unit = 1000 - if bytes < unit { - return fmt.Sprintf("%d B", bytes) - } - div, exp := unit, 0 - for n := bytes / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - var unitSymbol string - switch exp { - case 0: - unitSymbol = "KB" - case 1: - unitSymbol = "MB" - case 2: - unitSymbol = "GB" - case 3: - unitSymbol = "TB" - case 4: - unitSymbol = "PB" - default: - unitSymbol = "??" 
- } - return fmt.Sprintf("%.2f %s", float64(bytes)/float64(div), unitSymbol) -} - -// GetLocalDrives returns a slice of DriveInfo containing detailed information about each local drive -func GetLocalDrives() ([]DriveInfo, error) { - var drives []DriveInfo - - for _, drive := range "ABCDEFGHIJKLMNOPQRSTUVWXYZ" { - path := fmt.Sprintf("%c:\\", drive) - pathUtf16, err := windows.UTF16PtrFromString(path) - if err != nil { - continue // Skip invalid paths - } - - driveType := windows.GetDriveType(pathUtf16) - if driveType == windows.DRIVE_NO_ROOT_DIR { - continue // Drive not present - } - - var ( - volumeNameStr string - fileSystemStr string - totalBytes uint64 - freeBytes uint64 - ) - - // Retrieve volume information - var volumeName [windows.MAX_PATH + 1]uint16 - var fileSystemName [windows.MAX_PATH + 1]uint16 - if err := windows.GetVolumeInformation( - pathUtf16, - &volumeName[0], - uint32(len(volumeName)), - nil, - nil, - nil, - &fileSystemName[0], - uint32(len(fileSystemName)), - ); err == nil { - volumeNameStr = windows.UTF16ToString(volumeName[:]) - fileSystemStr = windows.UTF16ToString(fileSystemName[:]) - } - - // Retrieve disk space information - var totalFreeBytes uint64 - if err := windows.GetDiskFreeSpaceEx( - pathUtf16, - nil, - &totalBytes, - &totalFreeBytes, - ); err == nil { - freeBytes = totalFreeBytes - } - - usedBytes := totalBytes - freeBytes - - // Humanize byte counts - totalHuman := humanizeBytes(totalBytes) - usedHuman := humanizeBytes(usedBytes) - freeHuman := humanizeBytes(freeBytes) - - drives = append(drives, DriveInfo{ - Letter: string(drive), - Type: getDriveTypeString(driveType), - VolumeName: volumeNameStr, - FileSystem: fileSystemStr, - TotalBytes: totalBytes, - UsedBytes: usedBytes, - FreeBytes: freeBytes, - Total: totalHuman, - Used: usedHuman, - Free: freeHuman, - }) - } - - return drives, nil -} +//go:build windows + +package utils + +import ( + "fmt" + + "golang.org/x/sys/windows" +) + +// DriveInfo contains detailed information about a drive + +// getDriveTypeString returns a human-readable string describing the type of drive +func getDriveTypeString(dt uint32) string { + switch dt { + case windows.DRIVE_UNKNOWN: + return "Unknown" + case windows.DRIVE_NO_ROOT_DIR: + return "No Root Directory" + case windows.DRIVE_REMOVABLE: + return "Removable" + case windows.DRIVE_FIXED: + return "Fixed" + case windows.DRIVE_REMOTE: + return "Network" + case windows.DRIVE_CDROM: + return "CD-ROM" + case windows.DRIVE_RAMDISK: + return "RAM Disk" + default: + return "Unknown" + } +} + +// humanizeBytes converts a byte count into a human-readable string with appropriate units (KB, MB, GB, TB) +func humanizeBytes(bytes uint64) string { + const unit = 1000 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := unit, 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + var unitSymbol string + switch exp { + case 0: + unitSymbol = "KB" + case 1: + unitSymbol = "MB" + case 2: + unitSymbol = "GB" + case 3: + unitSymbol = "TB" + case 4: + unitSymbol = "PB" + default: + unitSymbol = "??" 
+ } + return fmt.Sprintf("%.2f %s", float64(bytes)/float64(div), unitSymbol) +} + +// GetLocalDrives returns a slice of DriveInfo containing detailed information about each local drive +func GetLocalDrives() ([]DriveInfo, error) { + var drives []DriveInfo + + for _, drive := range "ABCDEFGHIJKLMNOPQRSTUVWXYZ" { + path := fmt.Sprintf("%c:\\", drive) + pathUtf16, err := windows.UTF16PtrFromString(path) + if err != nil { + continue // Skip invalid paths + } + + driveType := windows.GetDriveType(pathUtf16) + if driveType == windows.DRIVE_NO_ROOT_DIR { + continue // Drive not present + } + + var ( + volumeNameStr string + fileSystemStr string + totalBytes uint64 + freeBytes uint64 + ) + + // Retrieve volume information + var volumeName [windows.MAX_PATH + 1]uint16 + var fileSystemName [windows.MAX_PATH + 1]uint16 + if err := windows.GetVolumeInformation( + pathUtf16, + &volumeName[0], + uint32(len(volumeName)), + nil, + nil, + nil, + &fileSystemName[0], + uint32(len(fileSystemName)), + ); err == nil { + volumeNameStr = windows.UTF16ToString(volumeName[:]) + fileSystemStr = windows.UTF16ToString(fileSystemName[:]) + } + + // Retrieve disk space information + var totalFreeBytes uint64 + if err := windows.GetDiskFreeSpaceEx( + pathUtf16, + nil, + &totalBytes, + &totalFreeBytes, + ); err == nil { + freeBytes = totalFreeBytes + } + + usedBytes := totalBytes - freeBytes + + // Humanize byte counts + totalHuman := humanizeBytes(totalBytes) + usedHuman := humanizeBytes(usedBytes) + freeHuman := humanizeBytes(freeBytes) + + drives = append(drives, DriveInfo{ + Letter: string(drive), + Type: getDriveTypeString(driveType), + VolumeName: volumeNameStr, + FileSystem: fileSystemStr, + TotalBytes: totalBytes, + UsedBytes: usedBytes, + FreeBytes: freeBytes, + Total: totalHuman, + Used: usedHuman, + Free: freeHuman, + }) + } + + return drives, nil +} diff --git a/internal/utils/local_ips.go b/internal/utils/local_ips.go index 5f2632d..0aed107 100644 --- a/internal/utils/local_ips.go +++ b/internal/utils/local_ips.go @@ -1,57 +1,57 @@ -package utils - -import ( - "fmt" - "net" - "net/http" - "strings" -) - -func GetLocalIPs() ([]string, error) { - var ips []string - ifaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - for _, iface := range ifaces { - if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { - continue - } - - addrs, err := iface.Addrs() - if err != nil { - continue - } - - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - if ip != nil && ip.To4() != nil { - ips = append(ips, ip.String()) - } - } - } - return ips, nil -} - -func IsRequestFromSelf(r *http.Request) bool { - remoteIP := strings.Split(r.RemoteAddr, ":")[0] // Extract the IP part - localIPs, err := GetLocalIPs() - if err != nil { - fmt.Println("Error fetching local IPs:", err) - return false - } - - for _, ip := range localIPs { - if remoteIP == ip { - return true - } - } - return false -} +package utils + +import ( + "fmt" + "net" + "net/http" + "strings" +) + +func GetLocalIPs() ([]string, error) { + var ips []string + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = 
v.IP + } + if ip != nil && ip.To4() != nil { + ips = append(ips, ip.String()) + } + } + } + return ips, nil +} + +func IsRequestFromSelf(r *http.Request) bool { + remoteIP := strings.Split(r.RemoteAddr, ":")[0] // Extract the IP part + localIPs, err := GetLocalIPs() + if err != nil { + fmt.Println("Error fetching local IPs:", err) + return false + } + + for _, ip := range localIPs { + if remoteIP == ip { + return true + } + } + return false +} diff --git a/internal/utils/path_check.go b/internal/utils/path_check.go index 39a3378..638da6a 100644 --- a/internal/utils/path_check.go +++ b/internal/utils/path_check.go @@ -1,53 +1,53 @@ -package utils - -import ( - "os" - "path/filepath" - "strings" - "unicode" -) - -const safeDir = "/home/user/" - -func IsValid(path string) bool { - // Check if path is not empty - if path == "" { - return false - } - - // Resolve the input path with respect to the safe directory - absPath, err := filepath.Abs(filepath.Join(safeDir, path)) - if err != nil || !strings.HasPrefix(absPath, safeDir) { - return false - } - - // Check if the path exists - _, err = os.Stat(absPath) - if err != nil { - if os.IsNotExist(err) { - return false - } - return false - } - - // Path exists, return true and no error - return true -} - -func IsValidPathString(path string) bool { - if path == "" { - return true - } - - if strings.Contains(path, "//") { - return false - } - - for _, r := range path { - if r == 0 || !unicode.IsPrint(r) { - return false - } - } - - return true -} +package utils + +import ( + "os" + "path/filepath" + "strings" + "unicode" +) + +const safeDir = "/home/user/" + +func IsValid(path string) bool { + // Check if path is not empty + if path == "" { + return false + } + + // Resolve the input path with respect to the safe directory + absPath, err := filepath.Abs(filepath.Join(safeDir, path)) + if err != nil || !strings.HasPrefix(absPath, safeDir) { + return false + } + + // Check if the path exists + _, err = os.Stat(absPath) + if err != nil { + if os.IsNotExist(err) { + return false + } + return false + } + + // Path exists, return true and no error + return true +} + +func IsValidPathString(path string) bool { + if path == "" { + return true + } + + if strings.Contains(path, "//") { + return false + } + + for _, r := range path { + if r == 0 || !unicode.IsPrint(r) { + return false + } + } + + return true +} diff --git a/internal/utils/ssh_keys.go b/internal/utils/ssh_keys.go index f1441b3..ef762d1 100644 --- a/internal/utils/ssh_keys.go +++ b/internal/utils/ssh_keys.go @@ -1,85 +1,85 @@ -package utils - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "fmt" - "sync" - - "golang.org/x/crypto/ssh" -) - -func GenerateKeyPair(bitSize int) ([]byte, []byte, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, bitSize) - if err != nil { - return nil, nil, fmt.Errorf("GenerateKey: error generating RSA key -> %w", err) - } - - err = privateKey.Validate() - if err != nil { - return nil, nil, fmt.Errorf("GenerateKey: error validating private key -> %w", err) - } - - publicKey, err := generatePublicKey(&privateKey.PublicKey) - if err != nil { - return nil, nil, fmt.Errorf("GenerateKey: error encoding to byte public key -> %w", err) - } - - encoded := encodePrivateKeyToPEM(privateKey) - - return encoded, publicKey, nil -} - -func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { - privDER := x509.MarshalPKCS1PrivateKey(privateKey) - - privBlock := pem.Block{ - Type: "RSA PRIVATE KEY", - Headers: nil, - Bytes: privDER, - } - - 
privatePEM := pem.EncodeToMemory(&privBlock) - - return privatePEM -} - -func generatePublicKey(privatekey *rsa.PublicKey) ([]byte, error) { - publicRsaKey, err := ssh.NewPublicKey(privatekey) - if err != nil { - return nil, fmt.Errorf("generatePublicKey: error creating new public key from private key -> %w", err) - } - - pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey) - - return pubKeyBytes, nil -} - -var pubKeyCache sync.Map - -func GeneratePublicKeyFromPrivateKey(encodedPrivateKey []byte) ([]byte, error) { - cached, ok := pubKeyCache.Load(string(encodedPrivateKey)) - if ok { - return cached.([]byte), nil - } - - block, _ := pem.Decode(encodedPrivateKey) - if block == nil || block.Type != "RSA PRIVATE KEY" { - return nil, fmt.Errorf("GeneratePublicKeyFromPrivateKey: invalid private key type or format") - } - - privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, fmt.Errorf("GeneratePublicKeyFromPrivateKey: error parsing private key -> %w", err) - } - - publicKey, err := generatePublicKey(&privateKey.PublicKey) - if err != nil { - return nil, fmt.Errorf("GeneratePublicKeyFromPrivateKey: error generating public key -> %w", err) - } - - pubKeyCache.Store(string(encodedPrivateKey), publicKey) - return publicKey, nil -} +package utils + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "sync" + + "golang.org/x/crypto/ssh" +) + +func GenerateKeyPair(bitSize int) ([]byte, []byte, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, bitSize) + if err != nil { + return nil, nil, fmt.Errorf("GenerateKey: error generating RSA key -> %w", err) + } + + err = privateKey.Validate() + if err != nil { + return nil, nil, fmt.Errorf("GenerateKey: error validating private key -> %w", err) + } + + publicKey, err := generatePublicKey(&privateKey.PublicKey) + if err != nil { + return nil, nil, fmt.Errorf("GenerateKey: error encoding to byte public key -> %w", err) + } + + encoded := encodePrivateKeyToPEM(privateKey) + + return encoded, publicKey, nil +} + +func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { + privDER := x509.MarshalPKCS1PrivateKey(privateKey) + + privBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: privDER, + } + + privatePEM := pem.EncodeToMemory(&privBlock) + + return privatePEM +} + +func generatePublicKey(privatekey *rsa.PublicKey) ([]byte, error) { + publicRsaKey, err := ssh.NewPublicKey(privatekey) + if err != nil { + return nil, fmt.Errorf("generatePublicKey: error creating new public key from private key -> %w", err) + } + + pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey) + + return pubKeyBytes, nil +} + +var pubKeyCache sync.Map + +func GeneratePublicKeyFromPrivateKey(encodedPrivateKey []byte) ([]byte, error) { + cached, ok := pubKeyCache.Load(string(encodedPrivateKey)) + if ok { + return cached.([]byte), nil + } + + block, _ := pem.Decode(encodedPrivateKey) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, fmt.Errorf("GeneratePublicKeyFromPrivateKey: invalid private key type or format") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("GeneratePublicKeyFromPrivateKey: error parsing private key -> %w", err) + } + + publicKey, err := generatePublicKey(&privateKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("GeneratePublicKeyFromPrivateKey: error generating public key -> %w", err) + } + + pubKeyCache.Store(string(encodedPrivateKey), publicKey) + return publicKey, nil +} diff 
--git a/internal/utils/task_path.go b/internal/utils/task_path.go index f4b50c7..0b7ae29 100644 --- a/internal/utils/task_path.go +++ b/internal/utils/task_path.go @@ -1,18 +1,18 @@ -package utils - -import ( - "fmt" - "strings" -) - -func GetTaskLogPath(upid string) string { - upidSplit := strings.Split(upid, ":") - if len(upidSplit) < 4 { - return "" - } - parsed := upidSplit[3] - logFolder := parsed[len(parsed)-2:] - logFilePath := fmt.Sprintf("/var/log/proxmox-backup/tasks/%s/%s", logFolder, upid) - - return logFilePath -} +package utils + +import ( + "fmt" + "strings" +) + +func GetTaskLogPath(upid string) string { + upidSplit := strings.Split(upid, ":") + if len(upidSplit) < 4 { + return "" + } + parsed := upidSplit[3] + logFolder := parsed[len(parsed)-2:] + logFilePath := fmt.Sprintf("/var/log/proxmox-backup/tasks/%s/%s", logFolder, upid) + + return logFilePath +} diff --git a/internal/utils/types.go b/internal/utils/types.go index b3846e6..5c016f3 100644 --- a/internal/utils/types.go +++ b/internal/utils/types.go @@ -1,14 +1,14 @@ -package utils - -type DriveInfo struct { - Letter string `json:"letter"` - Type string `json:"type"` - VolumeName string `json:"volume_name"` - FileSystem string `json:"filesystem"` - TotalBytes uint64 `json:"total_bytes"` - UsedBytes uint64 `json:"used_bytes"` - FreeBytes uint64 `json:"free_bytes"` - Total string `json:"total"` - Used string `json:"used"` - Free string `json:"free"` -} +package utils + +type DriveInfo struct { + Letter string `json:"letter"` + Type string `json:"type"` + VolumeName string `json:"volume_name"` + FileSystem string `json:"filesystem"` + TotalBytes uint64 `json:"total_bytes"` + UsedBytes uint64 `json:"used_bytes"` + FreeBytes uint64 `json:"free_bytes"` + Total string `json:"total"` + Used string `json:"used"` + Free string `json:"free"` +} diff --git a/internal/websockets/client.go b/internal/websockets/client.go index aec0a98..25200db 100644 --- a/internal/websockets/client.go +++ b/internal/websockets/client.go @@ -1,321 +1,321 @@ -package websockets - -import ( - "context" - "crypto/tls" - "fmt" - "net/http" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" -) - -const ( - maxRetryAttempts = 10 - messageTimeout = 5 * time.Second - operationTimeout = 10 * time.Second - maxMessageSize = 1024 * 1024 // 1MB - handlerPoolSize = 100 // Max concurrent message handlers -) - -type ( - MessageHandler func(ctx context.Context, msg *Message) error - - Config struct { - ServerURL string - ClientID string - Headers http.Header - TLSConfig *tls.Config - } - - WSClient struct { - config Config - - // Connection management - conn *websocket.Conn - connMu sync.RWMutex - isConnected atomic.Bool - - // Message handling - handlers map[string]MessageHandler - handlerMu sync.RWMutex - workerPool *WorkerPool - - // State management - ctx context.Context - cancel context.CancelFunc - closeOnce sync.Once - } - - WorkerPool struct { - workers chan struct{} - wg sync.WaitGroup - } -) - -func NewWorkerPool(size int) *WorkerPool { - return &WorkerPool{ - workers: make(chan struct{}, size), - } -} - -func (p *WorkerPool) Submit(ctx context.Context, task func()) error { - select { - case p.workers <- struct{}{}: - p.wg.Add(1) - go func() { - defer p.wg.Done() - defer func() { <-p.workers }() - task() - }() - return nil - case <-ctx.Done(): - return ctx.Err() - default: - return fmt.Errorf("worker pool full") - } -} - -func (p 
*WorkerPool) Wait() { - p.wg.Wait() -} - -func NewWSClient(ctx context.Context, config Config) (*WSClient, error) { - ctx, cancel := context.WithCancel(ctx) - - client := &WSClient{ - config: config, - handlers: make(map[string]MessageHandler), - workerPool: NewWorkerPool(handlerPoolSize), - ctx: ctx, - cancel: cancel, - } - - syslog.L.Infof("Initializing WebSocket client | server=%s client_id=%s", - client.config.ServerURL, client.config.ClientID) - return client, nil -} - -func (c *WSClient) Connect(ctx context.Context) error { - c.connMu.Lock() - defer c.connMu.Unlock() - - if c.isConnected.Load() { - syslog.L.Infof("Connection attempt skipped - already connected | client_id=%s", c.config.ClientID) - return nil - } - - syslog.L.Infof("Attempting WebSocket connection | server=%s client_id=%s", - c.config.ServerURL, c.config.ClientID) - - conn, _, err := websocket.Dial(ctx, c.config.ServerURL, &websocket.DialOptions{ - Subprotocols: []string{"pbs"}, - HTTPHeader: c.config.Headers, - HTTPClient: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: c.config.TLSConfig, - }, - }, - }) - - if err != nil { - syslog.L.Errorf("WebSocket connection failed | server=%s client_id=%s error=%v", - c.config.ServerURL, c.config.ClientID, err) - return fmt.Errorf("dial failed: %w", err) - } - - c.conn = conn - c.isConnected.Store(true) - syslog.L.Infof("WebSocket connection established | server=%s client_id=%s", - c.config.ServerURL, c.config.ClientID) - - // Start message handler in background - go c.handleMessages() - - return nil -} - -func (c *WSClient) handleMessages() { - for { - select { - case <-c.ctx.Done(): - syslog.L.Infof("Message handler stopping - context cancelled | client_id=%s", - c.config.ClientID) - return - default: - var message Message - err := wsjson.Read(c.ctx, c.conn, &message) - if err != nil { - // Filter out EOF errors and normal closure - if !isNormalClosureError(err) { - c.handleConnectionError(err) - } - return - } - - syslog.L.Infof("Received message | type=%s client_id=%s", - message.Type, c.config.ClientID) - c.handleMessage(&message) - } - } -} - -func (c *WSClient) handleMessage(msg *Message) { - c.handlerMu.RLock() - handler, exists := c.handlers[msg.Type] - c.handlerMu.RUnlock() - - if !exists { - syslog.L.Warnf("No handler registered | message_type=%s client_id=%s", - msg.Type, c.config.ClientID) - return - } - - ctx, cancel := context.WithTimeout(c.ctx, messageTimeout) - defer cancel() - - err := c.workerPool.Submit(ctx, func() { - start := time.Now() - if err := handler(ctx, msg); err != nil { - syslog.L.Errorf("Message handler failed | type=%s client_id=%s error=%v duration=%v", - msg.Type, c.config.ClientID, err, time.Since(start)) - } else { - syslog.L.Infof("Message handled successfully | type=%s client_id=%s duration=%v", - msg.Type, c.config.ClientID, time.Since(start)) - } - }) - - if err != nil { - syslog.L.Warnf("Worker pool submission failed | type=%s client_id=%s error=%v", - msg.Type, c.config.ClientID, err) - } -} - -func (c *WSClient) handleConnectionError(err error) { - if isNormalClosureError(err) { - syslog.L.Infof("WebSocket connection closed normally | client_id=%s", c.config.ClientID) - return - } - - syslog.L.Errorf("WebSocket connection error | client_id=%s error=%v", - c.config.ClientID, err) - c.isConnected.Store(false) - - // Attempt reconnection with backoff - for attempt := 0; attempt < maxRetryAttempts; attempt++ { - select { - case <-c.ctx.Done(): - return - case <-time.After(backoff(attempt)): - syslog.L.Infof("Attempting 
reconnection | attempt=%d/%d client_id=%s", - attempt+1, maxRetryAttempts, c.config.ClientID) - - ctx, cancel := context.WithTimeout(c.ctx, operationTimeout) - if err := c.Connect(ctx); err == nil { - syslog.L.Infof("Reconnection successful | client_id=%s attempt=%d", - c.config.ClientID, attempt+1) - cancel() - return - } - cancel() - } - } - - syslog.L.Errorf("Reconnection failed after %d attempts | client_id=%s", - maxRetryAttempts, c.config.ClientID) -} - -func (c *WSClient) Send(ctx context.Context, msg Message) error { - if !c.isConnected.Load() { - return fmt.Errorf("not connected") - } - - c.connMu.RLock() - defer c.connMu.RUnlock() - - start := time.Now() - err := wsjson.Write(ctx, c.conn, &msg) - if err != nil { - syslog.L.Errorf("Failed to send message | type=%s client_id=%s error=%v duration=%v", - msg.Type, c.config.ClientID, err, time.Since(start)) - return err - } - - syslog.L.Infof("Message sent successfully | type=%s client_id=%s duration=%v", - msg.Type, c.config.ClientID, time.Since(start)) - return nil -} - -func (c *WSClient) RegisterHandler(msgType string, handler MessageHandler) UnregisterFunc { - c.handlerMu.Lock() - c.handlers[msgType] = handler - c.handlerMu.Unlock() - - syslog.L.Infof("Registered message handler | type=%s client_id=%s", - msgType, c.config.ClientID) - - return func() { - c.handlerMu.Lock() - defer c.handlerMu.Unlock() - - if _, exists := c.handlers[msgType]; exists { - delete(c.handlers, msgType) - syslog.L.Infof("Unregistered message handler | type=%s client_id=%s", - msgType, c.config.ClientID) - } else { - syslog.L.Warnf("Attempted to unregister non-existent handler | type=%s client_id=%s", - msgType, c.config.ClientID) - } - } -} - -func (c *WSClient) Close() error { - var closeErr error - c.closeOnce.Do(func() { - syslog.L.Infof("Closing WebSocket client | client_id=%s", c.config.ClientID) - c.cancel() - - c.connMu.Lock() - defer c.connMu.Unlock() - - if c.conn != nil { - closeErr = c.conn.Close(websocket.StatusNormalClosure, "client closing") - if closeErr != nil { - syslog.L.Errorf("Error closing connection | client_id=%s error=%v", - c.config.ClientID, closeErr) - } - } - - c.isConnected.Store(false) - c.workerPool.Wait() - syslog.L.Infof("WebSocket client closed | client_id=%s", c.config.ClientID) - }) - - return closeErr -} - -// Helper function to identify normal closure errors -func isNormalClosureError(err error) bool { - return websocket.CloseStatus(err) == websocket.StatusNormalClosure || - strings.Contains(err.Error(), "context canceled") || - strings.Contains(err.Error(), "EOF") -} - -func backoff(attempt int) time.Duration { - base := time.Second - max := 30 * time.Second - duration := time.Duration(1<<attempt) * base - if duration > max { - duration = max - } - return duration -} - -func (c *WSClient) GetConnectionStatus() bool { - return c.isConnected.Load() -} +package websockets + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" +) + +const ( + maxRetryAttempts = 10 + messageTimeout = 5 * time.Second + operationTimeout = 10 * time.Second + maxMessageSize = 1024 * 1024 // 1MB + handlerPoolSize = 100 // Max concurrent message handlers +) + +type ( + MessageHandler func(ctx context.Context, msg *Message) error + + Config struct { + ServerURL string + ClientID string + Headers http.Header + TLSConfig *tls.Config + } + + WSClient struct { + config Config + + // Connection management
+ conn *websocket.Conn + connMu sync.RWMutex + isConnected atomic.Bool + + // Message handling + handlers map[string]MessageHandler + handlerMu sync.RWMutex + workerPool *WorkerPool + + // State management + ctx context.Context + cancel context.CancelFunc + closeOnce sync.Once + } + + WorkerPool struct { + workers chan struct{} + wg sync.WaitGroup + } +) + +func NewWorkerPool(size int) *WorkerPool { + return &WorkerPool{ + workers: make(chan struct{}, size), + } +} + +func (p *WorkerPool) Submit(ctx context.Context, task func()) error { + select { + case p.workers <- struct{}{}: + p.wg.Add(1) + go func() { + defer p.wg.Done() + defer func() { <-p.workers }() + task() + }() + return nil + case <-ctx.Done(): + return ctx.Err() + default: + return fmt.Errorf("worker pool full") + } +} + +func (p *WorkerPool) Wait() { + p.wg.Wait() +} + +func NewWSClient(ctx context.Context, config Config) (*WSClient, error) { + ctx, cancel := context.WithCancel(ctx) + + client := &WSClient{ + config: config, + handlers: make(map[string]MessageHandler), + workerPool: NewWorkerPool(handlerPoolSize), + ctx: ctx, + cancel: cancel, + } + + syslog.L.Infof("Initializing WebSocket client | server=%s client_id=%s", + client.config.ServerURL, client.config.ClientID) + return client, nil +} + +func (c *WSClient) Connect(ctx context.Context) error { + c.connMu.Lock() + defer c.connMu.Unlock() + + if c.isConnected.Load() { + syslog.L.Infof("Connection attempt skipped - already connected | client_id=%s", c.config.ClientID) + return nil + } + + syslog.L.Infof("Attempting WebSocket connection | server=%s client_id=%s", + c.config.ServerURL, c.config.ClientID) + + conn, _, err := websocket.Dial(ctx, c.config.ServerURL, &websocket.DialOptions{ + Subprotocols: []string{"pbs"}, + HTTPHeader: c.config.Headers, + HTTPClient: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: c.config.TLSConfig, + }, + }, + }) + + if err != nil { + syslog.L.Errorf("WebSocket connection failed | server=%s client_id=%s error=%v", + c.config.ServerURL, c.config.ClientID, err) + return fmt.Errorf("dial failed: %w", err) + } + + c.conn = conn + c.isConnected.Store(true) + syslog.L.Infof("WebSocket connection established | server=%s client_id=%s", + c.config.ServerURL, c.config.ClientID) + + // Start message handler in background + go c.handleMessages() + + return nil +} + +func (c *WSClient) handleMessages() { + for { + select { + case <-c.ctx.Done(): + syslog.L.Infof("Message handler stopping - context cancelled | client_id=%s", + c.config.ClientID) + return + default: + var message Message + err := wsjson.Read(c.ctx, c.conn, &message) + if err != nil { + // Filter out EOF errors and normal closure + if !isNormalClosureError(err) { + c.handleConnectionError(err) + } + return + } + + syslog.L.Infof("Received message | type=%s client_id=%s", + message.Type, c.config.ClientID) + c.handleMessage(&message) + } + } +} + +func (c *WSClient) handleMessage(msg *Message) { + c.handlerMu.RLock() + handler, exists := c.handlers[msg.Type] + c.handlerMu.RUnlock() + + if !exists { + syslog.L.Warnf("No handler registered | message_type=%s client_id=%s", + msg.Type, c.config.ClientID) + return + } + + ctx, cancel := context.WithTimeout(c.ctx, messageTimeout) + defer cancel() + + err := c.workerPool.Submit(ctx, func() { + start := time.Now() + if err := handler(ctx, msg); err != nil { + syslog.L.Errorf("Message handler failed | type=%s client_id=%s error=%v duration=%v", + msg.Type, c.config.ClientID, err, time.Since(start)) + } else { + 
syslog.L.Infof("Message handled successfully | type=%s client_id=%s duration=%v", + msg.Type, c.config.ClientID, time.Since(start)) + } + }) + + if err != nil { + syslog.L.Warnf("Worker pool submission failed | type=%s client_id=%s error=%v", + msg.Type, c.config.ClientID, err) + } +} + +func (c *WSClient) handleConnectionError(err error) { + if isNormalClosureError(err) { + syslog.L.Infof("WebSocket connection closed normally | client_id=%s", c.config.ClientID) + return + } + + syslog.L.Errorf("WebSocket connection error | client_id=%s error=%v", + c.config.ClientID, err) + c.isConnected.Store(false) + + // Attempt reconnection with backoff + for attempt := 0; attempt < maxRetryAttempts; attempt++ { + select { + case <-c.ctx.Done(): + return + case <-time.After(backoff(attempt)): + syslog.L.Infof("Attempting reconnection | attempt=%d/%d client_id=%s", + attempt+1, maxRetryAttempts, c.config.ClientID) + + ctx, cancel := context.WithTimeout(c.ctx, operationTimeout) + if err := c.Connect(ctx); err == nil { + syslog.L.Infof("Reconnection successful | client_id=%s attempt=%d", + c.config.ClientID, attempt+1) + cancel() + return + } + cancel() + } + } + + syslog.L.Errorf("Reconnection failed after %d attempts | client_id=%s", + maxRetryAttempts, c.config.ClientID) +} + +func (c *WSClient) Send(ctx context.Context, msg Message) error { + if !c.isConnected.Load() { + return fmt.Errorf("not connected") + } + + c.connMu.RLock() + defer c.connMu.RUnlock() + + start := time.Now() + err := wsjson.Write(ctx, c.conn, &msg) + if err != nil { + syslog.L.Errorf("Failed to send message | type=%s client_id=%s error=%v duration=%v", + msg.Type, c.config.ClientID, err, time.Since(start)) + return err + } + + syslog.L.Infof("Message sent successfully | type=%s client_id=%s duration=%v", + msg.Type, c.config.ClientID, time.Since(start)) + return nil +} + +func (c *WSClient) RegisterHandler(msgType string, handler MessageHandler) UnregisterFunc { + c.handlerMu.Lock() + c.handlers[msgType] = handler + c.handlerMu.Unlock() + + syslog.L.Infof("Registered message handler | type=%s client_id=%s", + msgType, c.config.ClientID) + + return func() { + c.handlerMu.Lock() + defer c.handlerMu.Unlock() + + if _, exists := c.handlers[msgType]; exists { + delete(c.handlers, msgType) + syslog.L.Infof("Unregistered message handler | type=%s client_id=%s", + msgType, c.config.ClientID) + } else { + syslog.L.Warnf("Attempted to unregister non-existent handler | type=%s client_id=%s", + msgType, c.config.ClientID) + } + } +} + +func (c *WSClient) Close() error { + var closeErr error + c.closeOnce.Do(func() { + syslog.L.Infof("Closing WebSocket client | client_id=%s", c.config.ClientID) + c.cancel() + + c.connMu.Lock() + defer c.connMu.Unlock() + + if c.conn != nil { + closeErr = c.conn.Close(websocket.StatusNormalClosure, "client closing") + if closeErr != nil { + syslog.L.Errorf("Error closing connection | client_id=%s error=%v", + c.config.ClientID, closeErr) + } + } + + c.isConnected.Store(false) + c.workerPool.Wait() + syslog.L.Infof("WebSocket client closed | client_id=%s", c.config.ClientID) + }) + + return closeErr +} + +// Helper function to identify normal closure errors +func isNormalClosureError(err error) bool { + return websocket.CloseStatus(err) == websocket.StatusNormalClosure || + strings.Contains(err.Error(), "context canceled") || + strings.Contains(err.Error(), "EOF") +} + +func backoff(attempt int) time.Duration { + base := time.Second + max := 30 * time.Second + duration := time.Duration(1< max { + duration = 
max + } + return duration +} + +func (c *WSClient) GetConnectionStatus() bool { + return c.isConnected.Load() +} diff --git a/internal/websockets/server.go b/internal/websockets/server.go index aa4b17d..ac0ab97 100644 --- a/internal/websockets/server.go +++ b/internal/websockets/server.go @@ -1,353 +1,353 @@ -package websockets - -import ( - "context" - "fmt" - "net/http" - "sync" - "time" - - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" - "github.com/sonroyaalmerol/pbs-plus/internal/syslog" -) - -const ( - workerPoolSize = 100 -) - -type ( - Message struct { - ClientID string `json:"client_id"` - Type string `json:"type"` - Content string `json:"content"` - Time time.Time `json:"time"` - } - - Client struct { - ID string - AgentVersion string - conn *websocket.Conn - server *Server - ctx context.Context - cancel context.CancelFunc - closeOnce sync.Once - } - - Server struct { - // Client management - clients map[string]*Client - clientsMu sync.RWMutex - - // Message handling - handlers map[string][]MessageHandler - handlerMu sync.RWMutex - workers *WorkerPool - - // Context management - ctx context.Context - cancel context.CancelFunc - } - - ServerOption func(*Server) -) - -func WithWorkerPoolSize(size int) ServerOption { - return func(s *Server) { - s.workers = NewWorkerPool(size) - syslog.L.Infof("Worker pool size configured | size=%d", size) - } -} - -func NewServer(ctx context.Context, opts ...ServerOption) *Server { - ctx, cancel := context.WithCancel(ctx) - - s := &Server{ - clients: make(map[string]*Client), - handlers: make(map[string][]MessageHandler), - workers: NewWorkerPool(workerPoolSize), - ctx: ctx, - cancel: cancel, - } - - syslog.L.Info("Initializing WebSocket server") - - for _, opt := range opts { - opt(s) - } - - syslog.L.Info("WebSocket server initialized successfully") - return s -} - -type UnregisterFunc func() - -func (s *Server) RegisterHandler(msgType string, handler MessageHandler) UnregisterFunc { - s.handlerMu.Lock() - currentHandlers := s.handlers[msgType] - handlerIndex := len(currentHandlers) - s.handlers[msgType] = append(currentHandlers, handler) - s.handlerMu.Unlock() - - syslog.L.Infof("Registered message handler | type=%s handler_count=%d", - msgType, handlerIndex+1) - - return func() { - s.handlerMu.Lock() - defer s.handlerMu.Unlock() - - handlers := s.handlers[msgType] - if handlerIndex >= len(handlers) { - syslog.L.Warnf("Handler already unregistered | type=%s handler_index=%d", - msgType, handlerIndex) - return - } - - newHandlers := make([]MessageHandler, 0, len(handlers)-1) - newHandlers = append(newHandlers, handlers[:handlerIndex]...) - if handlerIndex+1 < len(handlers) { - newHandlers = append(newHandlers, handlers[handlerIndex+1:]...) 
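// --- Editor's note (illustrative sketch, not part of the diff) ---
// Lifecycle of the WSClient defined above; the URL, client ID, and handler
// body are hypothetical. Reconnection is automatic only after an established
// connection drops: handleConnectionError retries up to maxRetryAttempts
// times with exponential backoff (1s, 2s, 4s, ..., capped at 30s). The
// initial Connect does not retry on failure.
//
//	ctx := context.Background()
//	client, _ := websockets.NewWSClient(ctx, websockets.Config{
//		ServerURL: "wss://pbs.example.com/plus/ws",
//		ClientID:  "agent-01",
//	})
//	unregister := client.RegisterHandler("ping", func(ctx context.Context, msg *websockets.Message) error {
//		return client.Send(ctx, websockets.Message{Type: "pong", Content: msg.Content})
//	})
//	defer unregister()
//	if err := client.Connect(ctx); err != nil {
//		// handle dial failure
//	}
//	defer client.Close()
// ---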
- } - - if len(newHandlers) == 0 { - delete(s.handlers, msgType) - } else { - s.handlers[msgType] = newHandlers - } - - syslog.L.Infof("Unregistered message handler | type=%s remaining_handlers=%d", - msgType, len(newHandlers)) - } -} - -func (s *Server) handleMessage(msg *Message) { - s.handlerMu.RLock() - handlers, exists := s.handlers[msg.Type] - handlerCount := len(handlers) - s.handlerMu.RUnlock() - - if !exists { - syslog.L.Warnf("No handlers registered | message_type=%s client_id=%s", - msg.Type, msg.ClientID) - return - } - - syslog.L.Infof("Processing message | type=%s client_id=%s handler_count=%d", - msg.Type, msg.ClientID, handlerCount) - - for i, handler := range handlers { - handler := handler // Create new variable for closure - ctx, cancel := context.WithTimeout(s.ctx, messageTimeout) - - start := time.Now() - err := s.workers.Submit(ctx, func() { - if err := handler(ctx, msg); err != nil { - syslog.L.Errorf("Handler error | type=%s client_id=%s handler_index=%d error=%v duration=%v", - msg.Type, msg.ClientID, i, err, time.Since(start)) - } else { - syslog.L.Infof("Handler completed | type=%s client_id=%s handler_index=%d duration=%v", - msg.Type, msg.ClientID, i, time.Since(start)) - } - }) - - if err != nil { - syslog.L.Errorf("Worker pool submission failed | type=%s client_id=%s handler_index=%d error=%v", - msg.Type, msg.ClientID, i, err) - } - - cancel() - } -} - -func (s *Server) ServeWS(w http.ResponseWriter, r *http.Request) { - clientID := r.Header.Get("X-PBS-Agent") - if clientID == "" { - syslog.L.Warnf("Rejected WebSocket connection | reason=missing_client_id remote_addr=%s", - r.RemoteAddr) - w.WriteHeader(http.StatusBadRequest) - return - } - - clientVersion := r.Header.Get("X-PBS-Plus-Version") - syslog.L.Infof("WebSocket connection request | client_id=%s version=%s remote_addr=%s", - clientID, clientVersion, r.RemoteAddr) - - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"pbs"}, - }) - if err != nil { - syslog.L.Errorf("WebSocket acceptance failed | client_id=%s error=%v", clientID, err) - return - } - - ctx, cancel := context.WithCancel(s.ctx) - client := &Client{ - ID: clientID, - AgentVersion: clientVersion, - conn: conn, - server: s, - ctx: ctx, - cancel: cancel, - } - - s.registerClient(client) - go s.handleClientConnection(client) - - <-client.ctx.Done() - s.unregisterClient(client) -} - -func (s *Server) handleClientConnection(client *Client) { - syslog.L.Infof("Starting client connection handler | client_id=%s version=%s", - client.ID, client.AgentVersion) - - defer syslog.L.Infof("Client connection handler stopped | client_id=%s", client.ID) - - for { - select { - case <-s.ctx.Done(): - return - case <-client.ctx.Done(): - return - default: - var msg Message - if err := wsjson.Read(client.ctx, client.conn, &msg); err != nil { - if !isNormalClosureError(err) { - syslog.L.Errorf("Message read error | client_id=%s error=%v", - client.ID, err) - } else { - syslog.L.Infof("Client connection closed normally | client_id=%s", - client.ID) - } - return - } - - msg.ClientID = client.ID - msg.Time = time.Now() - - syslog.L.Infof("Received message | type=%s client_id=%s", msg.Type, msg.ClientID) - s.handleMessage(&msg) - } - } -} - -func (s *Server) registerClient(client *Client) { - s.clientsMu.Lock() - s.clients[client.ID] = client - clientCount := len(s.clients) - s.clientsMu.Unlock() - - syslog.L.Infof("Client registered | id=%s version=%s total_clients=%d", - client.ID, client.AgentVersion, clientCount) -} - -func 
(s *Server) unregisterClient(client *Client) { - if client == nil { - return - } - - s.clientsMu.Lock() - if _, exists := s.clients[client.ID]; exists { - client.close() - delete(s.clients, client.ID) - clientCount := len(s.clients) - s.clientsMu.Unlock() - - syslog.L.Infof("Client unregistered | id=%s total_clients=%d", - client.ID, clientCount) - } else { - s.clientsMu.Unlock() - syslog.L.Warnf("Attempted to unregister non-existent client | id=%s", - client.ID) - } -} - -func (c *Client) close() { - c.closeOnce.Do(func() { - syslog.L.Infof("Closing client connection | id=%s", c.ID) - c.cancel() - if err := c.conn.Close(websocket.StatusNormalClosure, "client disconnecting"); err != nil { - syslog.L.Errorf("Error closing client connection | id=%s error=%v", c.ID, err) - } - }) -} - -func (s *Server) SendToClient(clientID string, msg Message) error { - s.clientsMu.RLock() - client, exists := s.clients[clientID] - s.clientsMu.RUnlock() - - if !exists { - return fmt.Errorf("client %s not connected", clientID) - } - - ctx, cancel := context.WithTimeout(client.ctx, messageTimeout) - defer cancel() - - start := time.Now() - err := wsjson.Write(ctx, client.conn, &msg) - if err != nil { - syslog.L.Errorf("Failed to send message | client_id=%s type=%s error=%v duration=%v", - clientID, msg.Type, err, time.Since(start)) - return err - } - - syslog.L.Infof("Message sent successfully | client_id=%s type=%s duration=%v", - clientID, msg.Type, time.Since(start)) - return nil -} - -func (s *Server) Shutdown(ctx context.Context) error { - syslog.L.Info("Starting server shutdown") - s.cancel() - - done := make(chan struct{}) - go func() { - s.cleanup() - close(done) - }() - - select { - case <-done: - syslog.L.Info("Server shutdown completed successfully") - return nil - case <-ctx.Done(): - syslog.L.Errorf("Server shutdown timed out | error=%v", ctx.Err()) - return ctx.Err() - } -} - -func (s *Server) cleanup() { - start := time.Now() - - s.clientsMu.Lock() - clientCount := len(s.clients) - for _, client := range s.clients { - client.close() - } - clear(s.clients) - s.clientsMu.Unlock() - - s.workers.Wait() - - syslog.L.Infof("Cleanup completed | clients_closed=%d duration=%v", - clientCount, time.Since(start)) -} - -func (s *Server) IsClientConnected(clientID string) bool { - s.clientsMu.RLock() - _, exists := s.clients[clientID] - s.clientsMu.RUnlock() - return exists -} - -func (s *Server) GetClientVersion(clientID string) string { - s.clientsMu.RLock() - client, exists := s.clients[clientID] - s.clientsMu.RUnlock() - - if exists { - return client.AgentVersion - } - return "" -} +package websockets + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/sonroyaalmerol/pbs-plus/internal/syslog" +) + +const ( + workerPoolSize = 100 +) + +type ( + Message struct { + ClientID string `json:"client_id"` + Type string `json:"type"` + Content string `json:"content"` + Time time.Time `json:"time"` + } + + Client struct { + ID string + AgentVersion string + conn *websocket.Conn + server *Server + ctx context.Context + cancel context.CancelFunc + closeOnce sync.Once + } + + Server struct { + // Client management + clients map[string]*Client + clientsMu sync.RWMutex + + // Message handling + handlers map[string][]MessageHandler + handlerMu sync.RWMutex + workers *WorkerPool + + // Context management + ctx context.Context + cancel context.CancelFunc + } + + ServerOption func(*Server) +) + +func WithWorkerPoolSize(size int) 
ServerOption { + return func(s *Server) { + s.workers = NewWorkerPool(size) + syslog.L.Infof("Worker pool size configured | size=%d", size) + } +} + +func NewServer(ctx context.Context, opts ...ServerOption) *Server { + ctx, cancel := context.WithCancel(ctx) + + s := &Server{ + clients: make(map[string]*Client), + handlers: make(map[string][]MessageHandler), + workers: NewWorkerPool(workerPoolSize), + ctx: ctx, + cancel: cancel, + } + + syslog.L.Info("Initializing WebSocket server") + + for _, opt := range opts { + opt(s) + } + + syslog.L.Info("WebSocket server initialized successfully") + return s +} + +type UnregisterFunc func() + +func (s *Server) RegisterHandler(msgType string, handler MessageHandler) UnregisterFunc { + s.handlerMu.Lock() + currentHandlers := s.handlers[msgType] + handlerIndex := len(currentHandlers) + s.handlers[msgType] = append(currentHandlers, handler) + s.handlerMu.Unlock() + + syslog.L.Infof("Registered message handler | type=%s handler_count=%d", + msgType, handlerIndex+1) + + return func() { + s.handlerMu.Lock() + defer s.handlerMu.Unlock() + + handlers := s.handlers[msgType] + if handlerIndex >= len(handlers) { + syslog.L.Warnf("Handler already unregistered | type=%s handler_index=%d", + msgType, handlerIndex) + return + } + + newHandlers := make([]MessageHandler, 0, len(handlers)-1) + newHandlers = append(newHandlers, handlers[:handlerIndex]...) + if handlerIndex+1 < len(handlers) { + newHandlers = append(newHandlers, handlers[handlerIndex+1:]...) + } + + if len(newHandlers) == 0 { + delete(s.handlers, msgType) + } else { + s.handlers[msgType] = newHandlers + } + + syslog.L.Infof("Unregistered message handler | type=%s remaining_handlers=%d", + msgType, len(newHandlers)) + } +} + +func (s *Server) handleMessage(msg *Message) { + s.handlerMu.RLock() + handlers, exists := s.handlers[msg.Type] + handlerCount := len(handlers) + s.handlerMu.RUnlock() + + if !exists { + syslog.L.Warnf("No handlers registered | message_type=%s client_id=%s", + msg.Type, msg.ClientID) + return + } + + syslog.L.Infof("Processing message | type=%s client_id=%s handler_count=%d", + msg.Type, msg.ClientID, handlerCount) + + for i, handler := range handlers { + handler := handler // Create new variable for closure + ctx, cancel := context.WithTimeout(s.ctx, messageTimeout) + + start := time.Now() + err := s.workers.Submit(ctx, func() { + if err := handler(ctx, msg); err != nil { + syslog.L.Errorf("Handler error | type=%s client_id=%s handler_index=%d error=%v duration=%v", + msg.Type, msg.ClientID, i, err, time.Since(start)) + } else { + syslog.L.Infof("Handler completed | type=%s client_id=%s handler_index=%d duration=%v", + msg.Type, msg.ClientID, i, time.Since(start)) + } + }) + + if err != nil { + syslog.L.Errorf("Worker pool submission failed | type=%s client_id=%s handler_index=%d error=%v", + msg.Type, msg.ClientID, i, err) + } + + cancel() + } +} + +func (s *Server) ServeWS(w http.ResponseWriter, r *http.Request) { + clientID := r.Header.Get("X-PBS-Agent") + if clientID == "" { + syslog.L.Warnf("Rejected WebSocket connection | reason=missing_client_id remote_addr=%s", + r.RemoteAddr) + w.WriteHeader(http.StatusBadRequest) + return + } + + clientVersion := r.Header.Get("X-PBS-Plus-Version") + syslog.L.Infof("WebSocket connection request | client_id=%s version=%s remote_addr=%s", + clientID, clientVersion, r.RemoteAddr) + + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"pbs"}, + }) + if err != nil { + syslog.L.Errorf("WebSocket acceptance 
failed | client_id=%s error=%v", clientID, err) + return + } + + ctx, cancel := context.WithCancel(s.ctx) + client := &Client{ + ID: clientID, + AgentVersion: clientVersion, + conn: conn, + server: s, + ctx: ctx, + cancel: cancel, + } + + s.registerClient(client) + go s.handleClientConnection(client) + + <-client.ctx.Done() + s.unregisterClient(client) +} + +func (s *Server) handleClientConnection(client *Client) { + syslog.L.Infof("Starting client connection handler | client_id=%s version=%s", + client.ID, client.AgentVersion) + + defer syslog.L.Infof("Client connection handler stopped | client_id=%s", client.ID) + + for { + select { + case <-s.ctx.Done(): + return + case <-client.ctx.Done(): + return + default: + var msg Message + if err := wsjson.Read(client.ctx, client.conn, &msg); err != nil { + if !isNormalClosureError(err) { + syslog.L.Errorf("Message read error | client_id=%s error=%v", + client.ID, err) + } else { + syslog.L.Infof("Client connection closed normally | client_id=%s", + client.ID) + } + return + } + + msg.ClientID = client.ID + msg.Time = time.Now() + + syslog.L.Infof("Received message | type=%s client_id=%s", msg.Type, msg.ClientID) + s.handleMessage(&msg) + } + } +} + +func (s *Server) registerClient(client *Client) { + s.clientsMu.Lock() + s.clients[client.ID] = client + clientCount := len(s.clients) + s.clientsMu.Unlock() + + syslog.L.Infof("Client registered | id=%s version=%s total_clients=%d", + client.ID, client.AgentVersion, clientCount) +} + +func (s *Server) unregisterClient(client *Client) { + if client == nil { + return + } + + s.clientsMu.Lock() + if _, exists := s.clients[client.ID]; exists { + client.close() + delete(s.clients, client.ID) + clientCount := len(s.clients) + s.clientsMu.Unlock() + + syslog.L.Infof("Client unregistered | id=%s total_clients=%d", + client.ID, clientCount) + } else { + s.clientsMu.Unlock() + syslog.L.Warnf("Attempted to unregister non-existent client | id=%s", + client.ID) + } +} + +func (c *Client) close() { + c.closeOnce.Do(func() { + syslog.L.Infof("Closing client connection | id=%s", c.ID) + c.cancel() + if err := c.conn.Close(websocket.StatusNormalClosure, "client disconnecting"); err != nil { + syslog.L.Errorf("Error closing client connection | id=%s error=%v", c.ID, err) + } + }) +} + +func (s *Server) SendToClient(clientID string, msg Message) error { + s.clientsMu.RLock() + client, exists := s.clients[clientID] + s.clientsMu.RUnlock() + + if !exists { + return fmt.Errorf("client %s not connected", clientID) + } + + ctx, cancel := context.WithTimeout(client.ctx, messageTimeout) + defer cancel() + + start := time.Now() + err := wsjson.Write(ctx, client.conn, &msg) + if err != nil { + syslog.L.Errorf("Failed to send message | client_id=%s type=%s error=%v duration=%v", + clientID, msg.Type, err, time.Since(start)) + return err + } + + syslog.L.Infof("Message sent successfully | client_id=%s type=%s duration=%v", + clientID, msg.Type, time.Since(start)) + return nil +} + +func (s *Server) Shutdown(ctx context.Context) error { + syslog.L.Info("Starting server shutdown") + s.cancel() + + done := make(chan struct{}) + go func() { + s.cleanup() + close(done) + }() + + select { + case <-done: + syslog.L.Info("Server shutdown completed successfully") + return nil + case <-ctx.Done(): + syslog.L.Errorf("Server shutdown timed out | error=%v", ctx.Err()) + return ctx.Err() + } +} + +func (s *Server) cleanup() { + start := time.Now() + + s.clientsMu.Lock() + clientCount := len(s.clients) + for _, client := range s.clients { + 
client.close() + } + clear(s.clients) + s.clientsMu.Unlock() + + s.workers.Wait() + + syslog.L.Infof("Cleanup completed | clients_closed=%d duration=%v", + clientCount, time.Since(start)) +} + +func (s *Server) IsClientConnected(clientID string) bool { + s.clientsMu.RLock() + _, exists := s.clients[clientID] + s.clientsMu.RUnlock() + return exists +} + +func (s *Server) GetClientVersion(clientID string) string { + s.clientsMu.RLock() + client, exists := s.clients[clientID] + s.clientsMu.RUnlock() + + if exists { + return client.AgentVersion + } + return "" +}
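// --- Editor's note (illustrative sketch, not part of the diff) ---
// Wiring the Server above into an HTTP mux; the route, port, and client ID
// are hypothetical. Agents must send the X-PBS-Agent header (ServeWS answers
// 400 without it); X-PBS-Plus-Version is optional and surfaces through
// GetClientVersion. ClientID and Time on inbound messages are stamped by the
// server on receipt.
//
//	ctx := context.Background()
//	srv := websockets.NewServer(ctx, websockets.WithWorkerPoolSize(200))
//	unregister := srv.RegisterHandler("status", func(ctx context.Context, msg *websockets.Message) error {
//		syslog.L.Infof("status from %s at %s: %s", msg.ClientID, msg.Time, msg.Content)
//		return nil
//	})
//	defer unregister()
//
//	mux := http.NewServeMux()
//	mux.HandleFunc("/plus/ws", srv.ServeWS)
//	go func() { _ = http.ListenAndServe(":8008", mux) }()
//
//	_ = srv.SendToClient("agent-01", websockets.Message{Type: "backup_start"})
//
//	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//	defer cancel()
//	_ = srv.Shutdown(shutdownCtx)
// ---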