From a91d98e1ad9945335b3b14ee4b64d79a9aa8004d Mon Sep 17 00:00:00 2001 From: Jacob Su Date: Sat, 26 Oct 2024 23:51:01 +0800 Subject: [PATCH] impl SRS /api/v1/clients api. --- proxy/api.go | 11 +++- proxy/http.go | 4 +- proxy/srs-api-proxy.go | 130 +++++++++++++++++++++++++++++++++++++++++ proxy/srs.go | 74 ++++++++++++++++++++++- proxy/sync/map.go | 10 ++++ proxy/utils.go | 15 +++-- 6 files changed, 233 insertions(+), 11 deletions(-) create mode 100644 proxy/srs-api-proxy.go diff --git a/proxy/api.go b/proxy/api.go index 04baa92526..f7f4681a1c 100644 --- a/proxy/api.go +++ b/proxy/api.go @@ -82,7 +82,7 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error { logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } }) @@ -90,10 +90,15 @@ func (v *srsHTTPAPIServer) Run(ctx context.Context) error { logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } }) + logger.Df(ctx, "Proxy /api/ to srs") + mux.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) { + srsLoadBalancer.ProxyHTTPAPI(ctx, w, r) + }) + // Run HTTP API server. v.wg.Add(1) go func() { @@ -239,7 +244,7 @@ func (v *systemAPI) Run(ctx context.Context) error { logger.Df(ctx, "Register SRS media server, %+v", server) return nil }(); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } type Response struct { diff --git a/proxy/http.go b/proxy/http.go index f02af02a30..92f5942f5f 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -198,7 +198,7 @@ func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) ctx := logger.WithContext(v.ctx) if err := v.serve(ctx, w, r); err != nil { - apiError(ctx, w, r, err) + apiError(ctx, w, r, err, http.StatusInternalServerError) } else { logger.Df(ctx, "HTTP client done") } @@ -318,7 +318,7 @@ func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() if err := v.serve(v.ctx, w, r); err != nil { - apiError(v.ctx, w, r, err) + apiError(v.ctx, w, r, err, http.StatusInternalServerError) } else { logger.Df(v.ctx, "HLS client %v for %v with %v done", v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) diff --git a/proxy/srs-api-proxy.go b/proxy/srs-api-proxy.go new file mode 100644 index 0000000000..2b7f63cb9f --- /dev/null +++ b/proxy/srs-api-proxy.go @@ -0,0 +1,130 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/json" + "io" + "net/http" + "srs-proxy/errors" + "srs-proxy/logger" + "strings" +) + +type SrsClient struct { + Id string `json:"id"` + Vhost string `json:"vhost"` + Stream string `json:"stream"` + Ip string `json:"ip"` + PageUrl string `json:"pageUrl"` + SwfUrl string `json:"swfUrl"` + TcUrl string `json:"tcUrl"` + Url string `json:"url"` + Name string `json:"name"` + Type string `json:"type"` + Publish bool `json:"publish"` + Alive float32 `json:"alive"` + SendBytes int `json:"send_bytes"` + RecvBytes int `json:"recv_bytes"` +} + +type SrsClientResponse struct { + Code int `json:"code"` + Server string `json:"server"` + Service string `json:"service"` + Pid string `json:"pid"` + Client SrsClient `json:"client"` +} + +type SrsClientsResponse struct { + Code int `json:"code"` + Server string `json:"server"` + Service string `json:"service"` + Pid string `json:"pid"` + Clients []SrsClient `json:"clients"` +} + +type SrsClientDeleteResponse struct { + Code int `json:"code"` +} + +type SrsApiProxy struct { +} + +func (v *SrsApiProxy) proxySrsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + if strings.HasPrefix(r.URL.Path, "/api/v1/clients") { + return proxySrsClientsAPI(ctx, servers, w, r) + } else if strings.HasPrefix(r.URL.Path, "/api/v1/streams") { + return proxySrsStreamsAPI(ctx, servers, w, r) + } + return nil +} + +// handle srs clients api /api/v1/clients +func proxySrsClientsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + + clientId := "" + if strings.HasPrefix(r.URL.Path, "/api/v1/clients/") { + clientId = r.URL.Path[len("/api/v1/clients/"):] + } + logger.Df(ctx, "%v %v clientId=%v", r.Method, r.URL.Path, clientId) + + body, err := io.ReadAll(r.Body) + if err != nil { + apiError(ctx, w, r, err, http.StatusInternalServerError) + return errors.Wrapf(err, "read request body err") + } + + switch r.Method { + case http.MethodDelete: + for _, server := range servers { + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + logger.Df(ctx, "response %v", string(ret)) + var res SrsClientDeleteResponse + if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 { + apiResponse(ctx, w, r, res) + return nil + } + } + } + + err := errors.Errorf("clientId %v not found in server", clientId) + apiError(ctx, w, r, err, http.StatusNotFound) + return err + case http.MethodGet: + if len(clientId) > 0 { + for _, server := range servers { + var client SrsClientResponse + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if err := json.Unmarshal(ret, &client); err == nil && client.Code == 0 { + apiResponse(ctx, w, r, client) + return nil + } + } + } + } else { // get all clients + var clients SrsClientsResponse + for _, server := range servers { + var res SrsClientsResponse + if ret, err := server.ApiRequest(ctx, r, body); err == nil { + if err := json.Unmarshal(ret, &res); err == nil && res.Code == 0 { + clients.Clients = append(clients.Clients, res.Clients...) + } + } + } + + apiResponse(ctx, w, r, clients) + return nil + } + default: + logger.Df(ctx, "/api/v1/clients %v", r.Method) + } + return nil +} + +func proxySrsStreamsAPI(ctx context.Context, servers []*SRSServer, w http.ResponseWriter, r *http.Request) error { + return nil +} diff --git a/proxy/srs.go b/proxy/srs.go index d05a39c610..b04df9ed79 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -4,10 +4,13 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" + "io" "math/rand" + "net/http" "os" "strconv" "strings" @@ -97,6 +100,28 @@ func (v *SRSServer) Format(f fmt.State, c rune) { } } +func (v *SRSServer) ApiRequest(ctx context.Context, r *http.Request, body []byte) ([]byte, error) { + url := "http://" + v.IP + ":" + v.API[0] + r.URL.Path + if r.URL.RawQuery != "" { + url += "?" + r.URL.RawQuery + } + + if req, err := http.NewRequestWithContext(ctx, r.Method, url, bytes.NewReader(body)); err != nil { + return nil, errors.Wrapf(err, "create request to %v", url) + } else if res, err := http.DefaultClient.Do(req); err != nil { + return nil, errors.Wrapf(err, "send request to %v", url) + } else { + defer res.Body.Close() + if ret, err := io.ReadAll(res.Body); err != nil { + return nil, errors.Wrapf(err, "read http respose error") + } else if !isHttpStatusOK(res.StatusCode) { + return ret, errors.Errorf("http response status code %v", res.StatusCode) + } else { + return ret, nil + } + } +} + func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { v := &SRSServer{} for _, opt := range opts { @@ -158,6 +183,8 @@ type SRSLoadBalancer interface { StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error // Load the WebRTC streaming by ufrag, the ICE username. LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) + // proxy http api to srs + ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error } // srsLoadBalancer is the global SRS load balancer. @@ -165,6 +192,7 @@ var srsLoadBalancer SRSLoadBalancer // srsMemoryLoadBalancer stores state in memory. type srsMemoryLoadBalancer struct { + *SrsApiProxy // All available SRS servers, key is server ID. servers sync.Map[string, *SRSServer] // The picked server to servce client by specified stream URL, key is stream url. @@ -287,7 +315,17 @@ func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag str } } +func (v *srsMemoryLoadBalancer) ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + services := make([]*SRSServer, v.servers.Size()) + v.servers.Range(func(_ string, value *SRSServer) bool { + services = append(services, value) + return true + }) + return v.proxySrsAPI(ctx, services, w, r) +} + type srsRedisLoadBalancer struct { + *SrsApiProxy // The redis client sdk. rdb *redis.Client } @@ -528,6 +566,40 @@ func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag stri return &actual, nil } +func (v *srsRedisLoadBalancer) ProxyHTTPAPI(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // No server found, failed. + if len(serverKeys) == 0 { + err := errors.New("servers empty") + apiError(ctx, w, r, err, http.StatusInternalServerError) + return err + } + + // TODO get all SRSServer + var srsServers []*SRSServer + + for _, key := range serverKeys { + var server SRSServer + if b, err := v.rdb.Get(ctx, key).Bytes(); err == nil { + if err := json.Unmarshal(b, &server); err != nil { + return errors.Wrapf(err, "unmarshal servers %v, %v", key, string(b)) + } + srsServers = append(srsServers, &server) + logger.Df(ctx, "srsServer: %v", server) + } + } + + return v.proxySrsAPI(ctx, srsServers, w, r) +} + func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) } @@ -549,5 +621,5 @@ func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { } func (v *srsRedisLoadBalancer) redisKeyServers() string { - return fmt.Sprintf("srs-proxy-all-servers") + return "srs-proxy-all-servers" } diff --git a/proxy/sync/map.go b/proxy/sync/map.go index 75db12f9a9..fe35dc91eb 100644 --- a/proxy/sync/map.go +++ b/proxy/sync/map.go @@ -43,3 +43,13 @@ func (m *Map[K, V]) Range(f func(key K, value V) bool) { func (m *Map[K, V]) Store(key K, value V) { m.m.Store(key, value) } + +func (m *Map[K, V]) Size() uint32 { + size := uint32(0) + m.m.Range(func(_, _ any) bool { + size++ + return true + }) + + return size +} diff --git a/proxy/utils.go b/proxy/utils.go index f3c3930762..fd84d6dd0e 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -32,7 +32,7 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da b, err := json.Marshal(data) if err != nil { - apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) + apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data), http.StatusInternalServerError) return } @@ -41,10 +41,10 @@ func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, da w.Write(b) } -func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { +func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error, code int) { logger.Wf(ctx, "HTTP API error %+v", err) w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(code) fmt.Fprintln(w, fmt.Sprintf("%v", err)) } @@ -69,6 +69,10 @@ func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { return false } +func isHttpStatusOK(v int) bool { + return v >= 200 && v < 300 +} + func parseGracefullyQuitTimeout() (time.Duration, error) { if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) @@ -250,8 +254,9 @@ func parseSRTStreamID(sid string) (host, resource string, err error) { } // parseListenEndpoint parse the listen endpoint as: -// port The tcp listen port, like 1935. -// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +// +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { // If no colon in ep, it's port in string. if !strings.Contains(ep, ":") {