From e9ee35f68ed70daef3573bba9c4855297e88c44b Mon Sep 17 00:00:00 2001 From: Loic Nageleisen Date: Wed, 1 Mar 2017 15:55:38 +0100 Subject: [PATCH] parens handling for where --- lib/rebel/sql.rb | 33 ++++++++++++++++++++++----------- test/test_raw.rb | 5 +++++ 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/lib/rebel/sql.rb b/lib/rebel/sql.rb index 99fb0c4..a5426fd 100644 --- a/lib/rebel/sql.rb +++ b/lib/rebel/sql.rb @@ -51,6 +51,20 @@ module Rebel::SQL end class Raw < String + def wants_parens! + @wants_parens = true + self + end + + def wants_parens? + @wants_parens = false unless instance_variable_defined?(:@wants_parens) + @wants_parens + end + + def parens + Raw.new("(#{self})") + end + def as(n) Raw.new(self + " AS #{Rebel::SQL.name(n)}") end @@ -72,7 +86,7 @@ module Rebel::SQL end def or(clause) - Raw.new("#{self} OR #{Rebel::SQL.and_clause(clause)}") + Raw.new("#{self} OR #{Rebel::SQL.and_clause(clause)}").wants_parens! end def eq(n) @@ -257,28 +271,25 @@ module Rebel::SQL def clause_term(left, right) case right when Array - "#{name(left)} IN (#{values(*right)})" + name(left).in(*right) else - "#{name(left)} = #{name_or_value(right)}" + name(left).eq(name_or_value(right)) end end - def and_clause(clause) - return clause if clause.is_a?(Raw) || clause.is_a?(String) - + def and_clause(*clause) clause.map do |e| case e when Array then clause_term(e[0], e[1]) - when Raw, String then e + when Raw then e.wants_parens? && clause.count > 1 ? "(#{e})" : e + when String then e else raise NotImplementedError, e.class end end.join(' AND ') end - def where?(clause) - return "WHERE #{clause}" if clause.is_a?(Raw) || clause.is_a?(String) - - clause && clause.any? ? "WHERE #{Rebel::SQL.and_clause(clause)}" : nil + def where?(*clause) + clause.any? ? "WHERE #{Rebel::SQL.and_clause(*clause)}" : nil end def inner?(join) diff --git a/test/test_raw.rb b/test/test_raw.rb index 4a0e0f2..fbc8552 100644 --- a/test/test_raw.rb +++ b/test/test_raw.rb @@ -52,4 +52,9 @@ class TestRaw < Minitest::Test def test_in assert_str_equal(Rebel::SQL.name(:foo).in(1, 2, 3), '"foo" IN (1, 2, 3)') end + + def test_where + assert_str_equal(Rebel::SQL.where?(Rebel::SQL.name(:foo).eq(1).or(Rebel::SQL.name(:bar).eq(2)), Rebel::SQL.name(:baz).eq(3)), 'WHERE ("foo" = 1 OR "bar" = 2) AND "baz" = 3') + assert_str_equal(Rebel::SQL.where?(Rebel::SQL.name(:foo).eq(1).or(Rebel::SQL.name(:bar).eq(2))), 'WHERE "foo" = 1 OR "bar" = 2') + end end