From 35a13c08dec874cf3bd67e9bbc79953243aa8274 Mon Sep 17 00:00:00 2001 From: Mark Nuzzolilo Date: Mon, 14 Oct 2024 20:43:18 -0700 Subject: [PATCH] Add hash return type for relation groupchain size --- lib/tapioca/dsl/compilers/active_record_relations.rb | 6 +++--- .../dsl/compilers/active_record_relations_spec.rb | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/tapioca/dsl/compilers/active_record_relations.rb b/lib/tapioca/dsl/compilers/active_record_relations.rb index f42345297..20aa676a7 100644 --- a/lib/tapioca/dsl/compilers/active_record_relations.rb +++ b/lib/tapioca/dsl/compilers/active_record_relations.rb @@ -372,7 +372,7 @@ def create_group_chain_methods(klass) return_type: "T.self_type", ) - CALCULATION_METHODS.each do |method_name| + (CALCULATION_METHODS + [:size]).each do |method_name| case method_name when :average, :maximum, :minimum klass.create_method( @@ -400,9 +400,9 @@ def create_group_chain_methods(klass) ], return_type: "T::Hash[T.untyped, Integer]", ) - when :sum + when :sum, :size klass.create_method( - "sum", + method_name.to_s, parameters: [ create_opt_param("column_name", type: "T.nilable(T.any(String, Symbol))", default: "nil"), create_block_param("block", type: "T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))"), diff --git a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb index e5defa1cb..d12e24ec6 100644 --- a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb +++ b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb @@ -658,6 +658,9 @@ def maximum(column_name); end sig { params(column_name: T.any(String, Symbol)).returns(T::Hash[T.untyped, T.untyped]) } def minimum(column_name); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } + def size(column_name = nil, &block); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } def sum(column_name = nil, &block); end end @@ -760,6 +763,9 @@ def maximum(column_name); end sig { params(column_name: T.any(String, Symbol)).returns(T::Hash[T.untyped, T.untyped]) } def minimum(column_name); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } + def size(column_name = nil, &block); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } def sum(column_name = nil, &block); end end @@ -1376,6 +1382,9 @@ def maximum(column_name); end sig { params(column_name: T.any(String, Symbol)).returns(T::Hash[T.untyped, T.untyped]) } def minimum(column_name); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } + def size(column_name = nil, &block); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } def sum(column_name = nil, &block); end end @@ -1478,6 +1487,9 @@ def maximum(column_name); end sig { params(column_name: T.any(String, Symbol)).returns(T::Hash[T.untyped, T.untyped]) } def minimum(column_name); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } + def size(column_name = nil, &block); end + sig { params(column_name: T.nilable(T.any(String, Symbol)), block: T.nilable(T.proc.params(record: T.untyped).returns(T.untyped))).returns(T::Hash[T.untyped, T.any(Integer, Float, BigDecimal)]) } def sum(column_name = nil, &block); end end