diff --git a/lib/redis/distributed.rb b/lib/redis/distributed.rb index b89a36d93..0ffaa2bd2 100644 --- a/lib/redis/distributed.rb +++ b/lib/redis/distributed.rb @@ -24,10 +24,14 @@ def initialize(node_configs, options = {}) @default_options = options.dup node_configs.each { |node_config| add_node(node_config) } @subscribed_node = nil + @watch_key = nil end def node_for(key) - @ring.get_node(key_tag(key.to_s) || key.to_s) + key = key_tag(key.to_s) || key.to_s + raise CannotDistribute, :watch if @watch_key && @watch_key != key + + @ring.get_node(key) end def nodes @@ -799,13 +803,26 @@ def punsubscribe(*channels) end # Watch the given keys to determine execution of the MULTI/EXEC block. - def watch(*_keys) - raise CannotDistribute, :watch + def watch(*keys, &block) + ensure_same_node(:watch, keys) do |node| + @watch_key = key_tag(keys.first) || keys.first.to_s + + begin + node.watch(*keys, &block) + rescue StandardError + @watch_key = nil + raise + end + end end # Forget about all watched keys. def unwatch - raise CannotDistribute, :unwatch + raise CannotDistribute, :unwatch unless @watch_key + + result = node_for(@watch_key).unwatch + @watch_key = nil + result end def pipelined @@ -813,18 +830,30 @@ def pipelined end # Mark the start of a transaction block. - def multi - raise CannotDistribute, :multi + def multi(&block) + raise CannotDistribute, :multi unless @watch_key + + result = node_for(@watch_key).multi(&block) + @watch_key = nil if block_given? + result end # Execute all commands issued after MULTI. def exec - raise CannotDistribute, :exec + raise CannotDistribute, :exec unless @watch_key + + result = node_for(@watch_key).exec + @watch_key = nil + result end # Discard all commands issued after MULTI. def discard - raise CannotDistribute, :discard + raise CannotDistribute, :discard unless @watch_key + + result = node_for(@watch_key).discard + @watch_key = nil + result end # Control remote script registry. diff --git a/test/distributed_transactions_test.rb b/test/distributed_transactions_test.rb index 73430441c..e1c77de9e 100644 --- a/test/distributed_transactions_test.rb +++ b/test/distributed_transactions_test.rb @@ -6,6 +6,22 @@ class TestDistributedTransactions < Minitest::Test include Helper::Distributed def test_multi_discard + r.set("foo", 1) + + r.watch("foo") + r.multi + r.set("foo", 2) + + assert_raises Redis::Distributed::CannotDistribute do + r.set("bar", 1) + end + + r.discard + + assert_equal('1', r.get("foo")) + end + + def test_multi_discard_without_watch @foo = nil assert_raises Redis::Distributed::CannotDistribute do @@ -19,13 +35,70 @@ def test_multi_discard end end - def test_watch_unwatch + def test_watch_unwatch_without_clustering assert_raises Redis::Distributed::CannotDistribute do - r.watch("foo") + r.watch("foo", "bar") + end + + r.watch("{qux}foo", "{qux}bar") do + assert_raises Redis::Distributed::CannotDistribute do + r.get("{baz}foo") + end + + r.unwatch end assert_raises Redis::Distributed::CannotDistribute do r.unwatch end end + + def test_watch_with_exception + assert_raises StandardError do + r.watch("{qux}foo", "{qux}bar") do + raise StandardError, "woops" + end + end + + assert_equal "OK", r.set("{other}baz", 1) + end + + def test_watch_unwatch + assert_equal "OK", r.watch("{qux}foo", "{qux}bar") + assert_equal "OK", r.unwatch + end + + def test_watch_multi_with_block + r.set("{qux}baz", 1) + + r.watch("{qux}foo", "{qux}bar", "{qux}baz") do + assert_equal '1', r.get("{qux}baz") + + result = r.multi do + r.incrby("{qux}foo", 3) + r.incrby("{qux}bar", 6) + r.incrby("{qux}baz", 9) + end + + assert_equal [3, 6, 10], result + end + end + + def test_watch_multi_exec_without_block + r.set("{qux}baz", 1) + + assert_equal "OK", r.watch("{qux}foo", "{qux}bar", "{qux}baz") + assert_equal '1', r.get("{qux}baz") + + assert_raises Redis::Distributed::CannotDistribute do + r.get("{foo}baz") + end + + assert_equal "OK", r.multi + assert_equal "QUEUED", r.incrby("{qux}baz", 1) + assert_equal "QUEUED", r.incrby("{qux}baz", 1) + assert_equal [2, 3], r.exec + + assert_equal "OK", r.set("{other}baz", 1) + end end