diff --git a/level.go b/level.go index 98dd191..7e6e77b 100644 --- a/level.go +++ b/level.go @@ -8,6 +8,7 @@ import ( "errors" "strings" "sync" + "sync/atomic" ) // ErrInvalidLogLevel is used when an invalid log level has been used. @@ -65,9 +66,11 @@ type LeveledBackend interface { Leveled } +type moduleLeveledMap map[string]Level type moduleLeveled struct { - levels map[string]Level - backend Backend + lock sync.Mutex + levels atomic.Value + backend atomic.Value formatter Formatter once sync.Once } @@ -78,19 +81,20 @@ func AddModuleLevel(backend Backend) LeveledBackend { var leveled LeveledBackend var ok bool if leveled, ok = backend.(LeveledBackend); !ok { - leveled = &moduleLeveled{ - levels: make(map[string]Level), - backend: backend, - } + modLeveled := &moduleLeveled{} + modLeveled.levels.Store(make(moduleLeveledMap)) + modLeveled.backend.Store(backend) + leveled = modLeveled } return leveled } // GetLevel returns the log level for the given module. func (l *moduleLeveled) GetLevel(module string) Level { - level, exists := l.levels[module] + levels := l.levels.Load().(moduleLeveledMap) + level, exists := levels[module] if exists == false { - level, exists = l.levels[""] + level, exists = levels[""] // no configuration exists, default to debug if exists == false { level = DEBUG @@ -101,7 +105,10 @@ func (l *moduleLeveled) GetLevel(module string) Level { // SetLevel sets the log level for the given module. func (l *moduleLeveled) SetLevel(level Level, module string) { - l.levels[module] = level + levels := l.levels.Load().(moduleLeveledMap) + l.lock.Lock() + levels[module] = level + l.lock.Unlock() } // IsEnabledFor will return true if logging is enabled for the given module. @@ -113,7 +120,7 @@ func (l *moduleLeveled) Log(level Level, calldepth int, rec *Record) (err error) if l.IsEnabledFor(level, rec.Module) { // TODO get rid of traces of formatter here. BackendFormatter should be used. rec.formatter = l.getFormatterAndCacheCurrent() - err = l.backend.Log(level, calldepth+1, rec) + err = l.backend.Load().(Backend).Log(level, calldepth+1, rec) } return } diff --git a/level_test.go b/level_test.go index c8f9a37..beea038 100644 --- a/level_test.go +++ b/level_test.go @@ -74,3 +74,48 @@ func TestLevelModuleLevel(t *testing.T) { } } } + +func testConcurrent_SetLevel(i int, sync *syncTestConcurrent, backend interface{}) { + sync.start.Done() + sync.start.Wait() + for j := 0; j < 1000; j++ { + leveled := backend.(LeveledBackend) + leveled.SetLevel(NOTICE, "") + leveled.SetLevel(ERROR, "foo") + leveled.SetLevel(INFO, "foo.bar") + leveled.SetLevel(WARNING, "bar") + } + sync.end.Done() +} + +func TestLevelModuleLevel_Concurency(t *testing.T) { + backend := NewMemoryBackend(128) + + leveled := AddModuleLevel(backend) + + sync := &syncTestConcurrent{} + sync.end.Add(10) + sync.start.Add(10) + for i := 0; i < 10; i++ { + go testConcurrent_SetLevel(i, sync, leveled) + } + sync.end.Wait() + + expected := []struct { + level Level + module string + }{ + {NOTICE, ""}, + {NOTICE, "something"}, + {ERROR, "foo"}, + {INFO, "foo.bar"}, + {WARNING, "bar"}, + } + + for _, e := range expected { + actual := leveled.GetLevel(e.module) + if e.level != actual { + t.Errorf("unexpected level in %s: %s != %s", e.module, e.level, actual) + } + } +} diff --git a/log_test.go b/log_test.go index c7a645f..abb00e7 100644 --- a/log_test.go +++ b/log_test.go @@ -9,9 +9,14 @@ import ( "io/ioutil" "log" "strings" + "sync" "testing" ) +type syncTestConcurrent struct { + start, end sync.WaitGroup +} + func TestLogCalldepth(t *testing.T) { buf := &bytes.Buffer{} SetBackend(NewLogBackend(buf, "", log.Lshortfile)) diff --git a/logger.go b/logger.go index 535ed9b..9d97b45 100644 --- a/logger.go +++ b/logger.go @@ -91,9 +91,8 @@ func (r *Record) Message() string { // Logger is the actual logger which creates log records based on the functions // called and passes them to the underlying logging backend. type Logger struct { - Module string - backend LeveledBackend - haveBackend bool + Module string + backend atomic.Value // ExtraCallDepth can be used to add additional call depth when getting the // calling function. This is normally used when wrapping a logger. @@ -102,8 +101,7 @@ type Logger struct { // SetBackend overrides any previously defined backend for this logger. func (l *Logger) SetBackend(backend LeveledBackend) { - l.backend = backend - l.haveBackend = true + l.backend.Store(backend) } // TODO call NewLogger and remove MustGetLogger? @@ -162,8 +160,8 @@ func (l *Logger) log(lvl Level, format *string, args ...interface{}) { // methods, Info(), Fatal(), etc. // ExtraCallDepth allows this to be extended further up the stack in case we // are wrapping these methods, eg. to expose them package level - if l.haveBackend { - l.backend.Log(lvl, 2+l.ExtraCalldepth, record) + if l.backend.Load() != nil { + l.backend.Load().(LeveledBackend).Log(lvl, 2+l.ExtraCalldepth, record) return } diff --git a/logger_test.go b/logger_test.go index b9f7fe7..35d89b7 100644 --- a/logger_test.go +++ b/logger_test.go @@ -56,7 +56,45 @@ func TestPrivateBackend(t *testing.T) { if stdBackend.size > 0 { t.Errorf("something in stdBackend, size of backend: %d", stdBackend.size) } - if "to private baсkend" == MemoryRecordN(privateBackend, 0).Formatted(0) { - t.Error("logged to defaultBackend:", MemoryRecordN(privateBackend, 0)) + if privateBackend.size != 1 { + t.Errorf("privateBackend must contain something, size of backend: %d", privateBackend.size) + } + if "to private backend" != MemoryRecordN(privateBackend, 0).Formatted(0) { + t.Error("must be logged to privateBackend:", MemoryRecordN(privateBackend, 0)) + } +} + +func testConcurrent_Log(i int, sync *syncTestConcurrent, lvlBackend *LeveledBackend, log *Logger) { + sync.start.Done() + sync.start.Wait() + for j := 0; j < 1000; j++ { + log.SetBackend(*lvlBackend) + log.Debug("to private backend") + } + sync.end.Done() +} + +func TestPrivateBackend_Concurency(t *testing.T) { + stdBackend := InitForTesting(DEBUG) + log := MustGetLogger("test") + privateBackend := NewMemoryBackend(10240) + lvlBackend := AddModuleLevel(privateBackend) + + sync := &syncTestConcurrent{} + sync.end.Add(10) + sync.start.Add(10) + for i := 0; i < 10; i++ { + go testConcurrent_Log(i, sync, &lvlBackend, log) + } + sync.end.Wait() + + if stdBackend.size > 0 { + t.Errorf("something in stdBackend, size of backend: %d", stdBackend.size) + } + if privateBackend.size != 10*1000 { + t.Errorf("privateBackend must contain something, size of backend: %d", privateBackend.size) + } + if "to private backend" != MemoryRecordN(privateBackend, 0).Formatted(0) { + t.Error("must be logged to privateBackend:", MemoryRecordN(privateBackend, 0)) } } diff --git a/memory.go b/memory.go index 8d5152c..7f55f2b 100644 --- a/memory.go +++ b/memory.go @@ -68,21 +68,17 @@ func (b *MemoryBackend) Log(level Level, calldepth int, rec *Record) error { // head will both be nil. When we successfully set the tail and the previous // value was nil, it's safe to set the head to the current value too. for { - tailp := b.tail - swapped := atomic.CompareAndSwapPointer( + tailp := atomic.SwapPointer( &b.tail, - tailp, np, ) - if swapped == true { - if tailp == nil { - b.head = np - } else { - (*node)(tailp).next = n - } - size = atomic.AddInt32(&b.size, 1) - break + if tailp == nil { + b.head = np + } else { + (*node)(tailp).next = n } + size = atomic.AddInt32(&b.size, 1) + break } // Since one record was added, we might have overflowed the list. Remove