Skip to content

Commit

Permalink
improve get protocol method for https
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonTian committed Apr 12, 2024
1 parent 081974d commit c085e8a
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 42 deletions.
16 changes: 12 additions & 4 deletions meta/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,20 @@ func (a *Api) GetMethod() string {
}

func (a *Api) GetProtocol() string {
protocol := strings.ToLower(a.Protocol)
if strings.HasPrefix(protocol, "https") {
lowered := strings.ToLower(a.Protocol)

if strings.HasPrefix(lowered, "https") {
return "https"
} else {
return "http"
}

parts := strings.Split(lowered, "|")
for _, v := range parts {
if v == "https" {
return "https"
}
}

return "http"
}

func (a *Api) FindParameter(name string) *Parameter {
Expand Down
18 changes: 17 additions & 1 deletion meta/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
Expand Down Expand Up @@ -45,6 +45,22 @@ func TestApi_GetProtocol(t *testing.T) {
api.Protocol = "http://"
protocol = api.GetProtocol()
assert.Equal(t, protocol, "http")

api.Protocol = "HTTP"
protocol = api.GetProtocol()
assert.Equal(t, protocol, "http")

api.Protocol = "HTTPS"
protocol = api.GetProtocol()
assert.Equal(t, protocol, "https")

api.Protocol = "HTTP|HTTPS"
protocol = api.GetProtocol()
assert.Equal(t, protocol, "https")

api.Protocol = "HTTPS|HTTP"
protocol = api.GetProtocol()
assert.Equal(t, protocol, "https")
}

func TestApi_FindParameter(t *testing.T) {
Expand Down
81 changes: 60 additions & 21 deletions openapi/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

func AddFlags(fs *cli.FlagSet) {
fs.Add(NewSecureFlag())
fs.Add(NewInsecureFlag())
fs.Add(NewForceFlag())
fs.Add(NewEndpointFlag())
fs.Add(NewVersionFlag())
Expand All @@ -38,6 +39,7 @@ func AddFlags(fs *cli.FlagSet) {

const (
SecureFlagName = "secure"
InsecureFlagName = "insecure"
ForceFlagName = "force"
EndpointFlagName = "endpoint"
VersionFlagName = "version"
Expand All @@ -60,6 +62,10 @@ func SecureFlag(fs *cli.FlagSet) *cli.Flag {
return fs.Get(SecureFlagName)
}

func InsecureFlag(fs *cli.FlagSet) *cli.Flag {
return fs.Get(InsecureFlagName)
}

func ForceFlag(fs *cli.FlagSet) *cli.Flag {
return fs.Get(ForceFlagName)
}
Expand Down Expand Up @@ -116,72 +122,102 @@ func MethodFlag(fs *cli.FlagSet) *cli.Flag {
//}

func NewSecureFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: SecureFlagName, AssignedMode: cli.AssignedNone,
return &cli.Flag{
Category: "caller",
Name: SecureFlagName,
AssignedMode: cli.AssignedNone,
Short: i18n.T(
"use `--secure` to force https",
"使用 `--secure` 开关强制使用https方式调用")}
}

func NewInsecureFlag() *cli.Flag {
return &cli.Flag{
Category: "caller",
Name: InsecureFlagName,
AssignedMode: cli.AssignedNone,
Hidden: true,
Short: i18n.T(
"use `--insecure` to force http(not recommend)",
"使用 `--insecure` 开关强制使用http方式调用(不推荐)")}
}

func NewForceFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: ForceFlagName, AssignedMode: cli.AssignedNone,
return &cli.Flag{
Category: "caller",
Name: ForceFlagName,
AssignedMode: cli.AssignedNone,
Short: i18n.T(
"use `--force` to skip api and parameters check",
"添加 `--force` 开关可跳过API与参数的合法性检查")}
}

func NewEndpointFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: EndpointFlagName, AssignedMode: cli.AssignedOnce,
return &cli.Flag{
Category: "caller",
Name: EndpointFlagName,
AssignedMode: cli.AssignedOnce,
Short: i18n.T(
"use `--endpoint <endpoint>` to assign endpoint",
"使用 `--endpoint <endpoint>` 来指定接入点地址")}
}

func NewVersionFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: VersionFlagName, AssignedMode: cli.AssignedOnce,
return &cli.Flag{
Category: "caller",
Name: VersionFlagName,
AssignedMode: cli.AssignedOnce,
Short: i18n.T(
"use `--version <YYYY-MM-DD>` to assign product api version",
"使用 `--version <YYYY-MM-DD>` 来指定访问的API版本")}
}

func NewHeaderFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: HeaderFlagName, AssignedMode: cli.AssignedRepeatable,
return &cli.Flag{
Category: "caller",
Name: HeaderFlagName, AssignedMode: cli.AssignedRepeatable,
Short: i18n.T(
"use `--header X-foo=bar` to add custom HTTP header, repeatable",
"使用 `--header X-foo=bar` 来添加特定的HTTP头, 可多次添加")}
}

func NewBodyFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: BodyFlagName, AssignedMode: cli.AssignedOnce,
return &cli.Flag{
Category: "caller",
Name: BodyFlagName, AssignedMode: cli.AssignedOnce,
Short: i18n.T(
"use `--body $(cat foo.json)` to assign http body in RESTful call",
"使用 `--body $(cat foo.json)` 来指定在RESTful调用中的HTTP包体")}
}

func NewBodyFileFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: BodyFileFlagName, AssignedMode: cli.AssignedOnce, Hidden: true,
return &cli.Flag{
Category: "caller",
Name: BodyFileFlagName,
AssignedMode: cli.AssignedOnce,
Hidden: true,
Short: i18n.T(
"assign http body in Restful call with local file",
"使用 `--body-file foo.json` 来指定输入包体")}
}

func NewAcceptFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: AcceptFlagName, AssignedMode: cli.AssignedOnce, Hidden: true,
return &cli.Flag{
Category: "caller",
Name: AcceptFlagName,
AssignedMode: cli.AssignedOnce,
Hidden: true,
Short: i18n.T(
"add `--accept {json|xml}` to add Accept header",
"使用 `--accept {json|xml}` 来指定Accept头")}
}

func NewRoaFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
Name: RoaFlagName, AssignedMode: cli.AssignedOnce, Hidden: true,
return &cli.Flag{
Category: "caller",
Name: RoaFlagName,
AssignedMode: cli.AssignedOnce,
Hidden: true,
Short: i18n.T(
"use `--roa {GET|PUT|POST|DELETE}` to assign restful call.[DEPRECATED]",
"使用 `--roa {GET|PUT|POST|DELETE}` 使用restful方式调用[已过期]",
Expand All @@ -190,7 +226,8 @@ func NewRoaFlag() *cli.Flag {
}

func NewDryRunFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
return &cli.Flag{
Category: "caller",
Name: DryRunFlagName,
AssignedMode: cli.AssignedNone,
Short: i18n.T(
Expand All @@ -202,7 +239,8 @@ func NewDryRunFlag() *cli.Flag {
}

func NewQuietFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
return &cli.Flag{
Category: "caller",
Name: QuietFlagName,
Shorthand: 'q',
AssignedMode: cli.AssignedNone,
Expand All @@ -215,7 +253,8 @@ func NewQuietFlag() *cli.Flag {
}

func NewMethodFlag() *cli.Flag {
return &cli.Flag{Category: "caller",
return &cli.Flag{
Category: "caller",
Name: MethodFlagName,
AssignedMode: cli.AssignedOnce,
Short: i18n.T(
Expand Down
22 changes: 16 additions & 6 deletions openapi/force_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,41 @@ type ForceRpcInvoker struct {
method string
}

func (a *ForceRpcInvoker) Prepare(ctx *cli.Context) error {
func (a *ForceRpcInvoker) Prepare(ctx *cli.Context) (err error) {
// assign api name
a.request.ApiName = a.method
// default to use https
a.request.Scheme = "https"

// assign parameters
for _, f := range ctx.UnknownFlags().Flags() {
a.request.QueryParams[f.Name], _ = f.GetValue()
}

// --insecure use http
if _, ok := InsecureFlag(ctx.Flags()).GetValue(); ok {
a.request.Scheme = "http"
}

// --secure use https
if _, ok := SecureFlag(ctx.Flags()).GetValue(); ok {
a.request.Scheme = "https"
}

// if '--method' assigned, reset method
if method, ok := MethodFlag(ctx.Flags()).GetValue(); ok {
if method == "GET" || method == "POST" {
a.request.Method = method
} else {
return fmt.Errorf("--method value %s is not supported, please set method in {GET|POST}", method)
err = fmt.Errorf("--method value %s is not supported, please set method in {GET|POST}", method)
return
}
}
return nil

return
}

func (a *ForceRpcInvoker) Call() (*responses.CommonResponse, error) {
resp, err := a.client.ProcessCommonRequest(a.request)
return resp, err
func (a *ForceRpcInvoker) Call() (resp *responses.CommonResponse, err error) {
resp, err = a.client.ProcessCommonRequest(a.request)
return
}
1 change: 1 addition & 0 deletions openapi/force_rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func TestForceRpcInvoker_Prepare(t *testing.T) {
methodflag.SetAssigned(true)
methodflag.SetValue("POST")
ctx.Flags().Add(secureflag)
ctx.Flags().Add(NewInsecureFlag())
ctx.Flags().Add(methodflag)
ctx.UnknownFlags().Add(NewSecureFlag())
err := a.Prepare(ctx)
Expand Down
14 changes: 5 additions & 9 deletions openapi/library_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,29 @@ func TestLibrary_PrintProducts(t *testing.T) {
func TestLibrary_PrintProductUsage(t *testing.T) {
w := new(bytes.Buffer)
library := NewLibrary(w, "en")
content := `{"products":[{"code":"ecs","api_style":"rpc","apis":["DescribeRegions"]}]}`
library.builtinRepo = getRepository(content)
library.builtinRepo = getRepository()
err := library.PrintProductUsage("aos", true)
assert.Equal(t, "'aos' is not a valid command or product. See `aliyun help`.", err.Error())

err = library.PrintProductUsage("ecs", true)
assert.Nil(t, err)

content = `{"products":[{"code":"ecs","api_style":"restful","apis":["DescribeRegions"]}]}`
library.builtinRepo = getRepository(content)
library.builtinRepo = getRepository()
err = library.PrintProductUsage("ecs", true)
assert.Nil(t, err)
}

func TestLibrary_PrintApiUsage(t *testing.T) {
w := new(bytes.Buffer)
library := NewLibrary(w, "en")
content := `{"products":[{"code":"ecs","api_style":"rpc","apis":["DescribeRegions"]}]}`
library.builtinRepo = getRepository(content)
library.builtinRepo = getRepository()
err := library.PrintApiUsage("aos", "DescribeRegions")
assert.Equal(t, "'aos' is not a valid command or product. See `aliyun help`.", err.Error())

err = library.PrintApiUsage("ecs", "DescribeRegions")
assert.Nil(t, err)

content = `{"products":[{"code":"ecs","api_style":"restful","apis":["DescribeRegions"]}]}`
library.builtinRepo = getRepository(content)
library.builtinRepo = getRepository()
err = library.PrintApiUsage("ecs", "DescribeRegions")
assert.Nil(t, err)
}
Expand Down Expand Up @@ -101,7 +97,7 @@ func Test_printParameters(t *testing.T) {
printParameters(w, params, "", &newmeta.APIDetail{})
}

func getRepository(content string) *meta.Repository {
func getRepository() *meta.Repository {
repository := meta.LoadRepository()
return repository
}
8 changes: 8 additions & 0 deletions openapi/restful.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func (a *RestfulInvoker) Prepare(ctx *cli.Context) error {
for _, f := range ctx.UnknownFlags().Flags() {
a.request.QueryParams[f.Name], _ = f.GetValue()
}
// default to https
a.request.Scheme = "https"
} else {
for _, f := range ctx.UnknownFlags().Flags() {
param := a.api.FindParameter(f.Name)
Expand All @@ -81,6 +83,12 @@ func (a *RestfulInvoker) Prepare(ctx *cli.Context) error {
return fmt.Errorf("unknown parameter position; %s is %s", param.Name, param.Position)
}
}

a.request.Scheme = a.api.GetProtocol()
}

if _, ok := InsecureFlag(ctx.Flags()).GetValue(); ok {
a.request.Scheme = "http"
}

if _, ok := SecureFlag(ctx.Flags()).GetValue(); ok {
Expand Down
2 changes: 1 addition & 1 deletion openapi/restful_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
Expand Down
5 changes: 5 additions & 0 deletions openapi/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ func (a *RpcInvoker) Prepare(ctx *cli.Context) error {
request.Scheme = api.GetProtocol()
request.Method = api.GetMethod()

// if `--insecure` assigned, use https
if _, ok := InsecureFlag(ctx.Flags()).GetValue(); ok {
a.request.Scheme = "https"
}

// if `--secure` assigned, use https
if _, ok := SecureFlag(ctx.Flags()).GetValue(); ok {
a.request.Scheme = "https"
Expand Down
1 change: 1 addition & 0 deletions openapi/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func TestRpcInvoker_Prepare(t *testing.T) {
secureflag := NewSecureFlag()
secureflag.SetAssigned(true)
ctx.Flags().Add(secureflag)
ctx.Flags().Add(NewInsecureFlag())
methodflag := NewMethodFlag()
methodflag.SetAssigned(true)
methodflag.SetValue("POST")
Expand Down

0 comments on commit c085e8a

Please sign in to comment.