diff --git a/duration.go b/duration.go index 68ef526..1723553 100644 --- a/duration.go +++ b/duration.go @@ -2,6 +2,7 @@ package duration import ( "errors" + "fmt" "math" "strconv" "time" @@ -275,3 +276,26 @@ func (duration *Duration) String() string { return d } + +func (duration Duration) MarshalJSON() ([]byte, error) { + return []byte("\"" + duration.String() + "\""), nil +} + +func (duration *Duration) UnmarshalJSON(source []byte) error { + strVal := string(source) + if len(strVal) < 2 { + return fmt.Errorf("invalid ISO 8601 duration: %s", strVal) + } + strVal = strVal[1 : len(strVal)-1] + + if strVal == "null" { + return nil + } + + parsed, err := Parse(strVal) + if err != nil { + return fmt.Errorf("invalid ISO 8601 duration: %s", strVal) + } + *duration = *parsed + return nil +} diff --git a/duration_test.go b/duration_test.go index a664f03..c2b6e0f 100644 --- a/duration_test.go +++ b/duration_test.go @@ -1,6 +1,7 @@ package duration import ( + "encoding/json" "reflect" "testing" "time" @@ -262,3 +263,64 @@ func TestDuration_String(t *testing.T) { t.Errorf("expected: %s, got: %s", "-PT2H5M", negativeDuration.String()) } } + +func TestDuration_MarshalJSON(t *testing.T) { + td, err := Parse("P3Y6M4DT12H30M5.5S") + if err != nil { + t.Fatal(err) + } + + jsonVal, err := json.Marshal(struct { + Dur *Duration `json:"d"` + }{Dur: td}) + if err != nil { + t.Errorf("did not expect error: %s", err.Error()) + } + if string(jsonVal) != `{"d":"P3Y6M4DT12H30M5.5S"}` { + t.Errorf("expected: %s, got: %s", `{"d":"P3Y6M4DT12H30M5.5S"}`, string(jsonVal)) + } + + jsonVal, err = json.Marshal(struct { + Dur Duration `json:"d"` + }{Dur: *td}) + if err != nil { + t.Errorf("did not expect error: %s", err.Error()) + } + if string(jsonVal) != `{"d":"P3Y6M4DT12H30M5.5S"}` { + t.Errorf("expected: %s, got: %s", `{"d":"P3Y6M4DT12H30M5.5S"}`, string(jsonVal)) + } +} + +func TestDuration_UnmarshalJSON(t *testing.T) { + jsonStr := ` + { + "d": "P3Y6M4DT12H30M5.5S" + } + ` + expected, err := Parse("P3Y6M4DT12H30M5.5S") + if err != nil { + t.Fatal(err) + } + + var durStructPtr struct { + Dur *Duration `json:"d"` + } + err = json.Unmarshal([]byte(jsonStr), &durStructPtr) + if err != nil { + t.Errorf("did not expect error: %s", err.Error()) + } + if !reflect.DeepEqual(durStructPtr.Dur, expected) { + t.Errorf("JSON Unmarshal ptr got = %s, want %s", durStructPtr.Dur, expected) + } + + var durStruct struct { + Dur Duration `json:"d"` + } + err = json.Unmarshal([]byte(jsonStr), &durStruct) + if err != nil { + t.Errorf("did not expect error: %s", err.Error()) + } + if !reflect.DeepEqual(durStruct.Dur, *expected) { + t.Errorf("JSON Unmarshal ptr got = %s, want %s", &(durStruct.Dur), expected) + } +}