parens handling for where

This commit is contained in:
Loic Nageleisen 2017-03-01 15:55:38 +01:00
parent 814aa46592
commit e9ee35f68e
2 changed files with 27 additions and 11 deletions

View file

@ -51,6 +51,20 @@ module Rebel::SQL
end end
class Raw < String class Raw < String
def wants_parens!
@wants_parens = true
self
end
def wants_parens?
@wants_parens = false unless instance_variable_defined?(:@wants_parens)
@wants_parens
end
def parens
Raw.new("(#{self})")
end
def as(n) def as(n)
Raw.new(self + " AS #{Rebel::SQL.name(n)}") Raw.new(self + " AS #{Rebel::SQL.name(n)}")
end end
@ -72,7 +86,7 @@ module Rebel::SQL
end end
def or(clause) def or(clause)
Raw.new("#{self} OR #{Rebel::SQL.and_clause(clause)}") Raw.new("#{self} OR #{Rebel::SQL.and_clause(clause)}").wants_parens!
end end
def eq(n) def eq(n)
@ -257,28 +271,25 @@ module Rebel::SQL
def clause_term(left, right) def clause_term(left, right)
case right case right
when Array when Array
"#{name(left)} IN (#{values(*right)})" name(left).in(*right)
else else
"#{name(left)} = #{name_or_value(right)}" name(left).eq(name_or_value(right))
end end
end end
def and_clause(clause) def and_clause(*clause)
return clause if clause.is_a?(Raw) || clause.is_a?(String)
clause.map do |e| clause.map do |e|
case e case e
when Array then clause_term(e[0], e[1]) when Array then clause_term(e[0], e[1])
when Raw, String then e when Raw then e.wants_parens? && clause.count > 1 ? "(#{e})" : e
when String then e
else raise NotImplementedError, e.class else raise NotImplementedError, e.class
end end
end.join(' AND ') end.join(' AND ')
end end
def where?(clause) def where?(*clause)
return "WHERE #{clause}" if clause.is_a?(Raw) || clause.is_a?(String) clause.any? ? "WHERE #{Rebel::SQL.and_clause(*clause)}" : nil
clause && clause.any? ? "WHERE #{Rebel::SQL.and_clause(clause)}" : nil
end end
def inner?(join) def inner?(join)

View file

@ -52,4 +52,9 @@ class TestRaw < Minitest::Test
def test_in def test_in
assert_str_equal(Rebel::SQL.name(:foo).in(1, 2, 3), '"foo" IN (1, 2, 3)') assert_str_equal(Rebel::SQL.name(:foo).in(1, 2, 3), '"foo" IN (1, 2, 3)')
end end
def test_where
assert_str_equal(Rebel::SQL.where?(Rebel::SQL.name(:foo).eq(1).or(Rebel::SQL.name(:bar).eq(2)), Rebel::SQL.name(:baz).eq(3)), 'WHERE ("foo" = 1 OR "bar" = 2) AND "baz" = 3')
assert_str_equal(Rebel::SQL.where?(Rebel::SQL.name(:foo).eq(1).or(Rebel::SQL.name(:bar).eq(2))), 'WHERE "foo" = 1 OR "bar" = 2')
end
end end