Skip to content

Commit 6631407

Browse files
committed
Merge pull request pandas-dev#7910 from mortada/nth_values
added support for selecting multiple nth values
2 parents 09a2415 + 31ec4e4 commit 6631407

File tree

3 files changed

+71
-12
lines changed

3 files changed

+71
-12
lines changed

doc/source/groupby.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ This shows the first or last n rows from each group.
869869
Taking the nth row of each group
870870
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
871871

872-
To select from a DataFrame or Series the nth item, use the nth method. This is a reduction method, and will return a single row (or no row) per group:
872+
To select the nth item from a DataFrame or Series, use the nth method. This is a reduction method, and will return a single row (or no row) per group if you pass an int for n:
873873

874874
.. ipython:: python
875875
@@ -880,7 +880,7 @@ To select from a DataFrame or Series the nth item, use the nth method. This is a
880880
g.nth(-1)
881881
g.nth(1)
882882
883-
If you want to select the nth not-null method, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'`` just like you would pass to dropna, for a Series this just needs to be truthy.
883+
If you want to select the nth not-null item, use the ``dropna`` kwarg. For a DataFrame this should be either ``'any'`` or ``'all'``, just as you would pass to ``dropna``; for a Series it just needs to be truthy.
884884

885885
.. ipython:: python
886886
@@ -904,6 +904,15 @@ As with other methods, passing ``as_index=False``, will achieve a filtration, wh
904904
g.nth(0)
905905
g.nth(-1)
906906
907+
You can also select multiple rows from each group by specifying multiple nth values as a list of ints.
908+
909+
.. ipython:: python
910+
911+
business_dates = date_range(start='4/1/2014', end='6/30/2014', freq='B')
912+
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
913+
# get the first, fourth, and last business day of each month
914+
df.groupby((df.index.year, df.index.month)).nth([0, 3, -1])
915+
907916
Enumerate group items
908917
~~~~~~~~~~~~~~~~~~~~~
909918

pandas/core/groupby.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -782,12 +782,21 @@ def ohlc(self):
782782

783783
def nth(self, n, dropna=None):
784784
"""
785-
Take the nth row from each group.
785+
Take the nth row from each group if n is an int, or a subset of rows
786+
if n is a list of ints.
786787
787-
If dropna, will not show nth non-null row, dropna is either
788+
If dropna, will take the nth non-null row, dropna is either
788789
Truthy (if a Series) or 'all', 'any' (if a DataFrame); this is equivalent
789790
to calling dropna(how=dropna) before the groupby.
790791
792+
Parameters
793+
----------
794+
n : int or list of ints
795+
a single nth value for the row or a list of nth values
796+
dropna : None or str, optional
797+
apply the specified dropna operation before counting which row is
798+
the nth row. Needs to be None, 'any' or 'all'
799+
791800
Examples
792801
--------
793802
>>> df = DataFrame([[1, np.nan], [1, 4], [5, 6]], columns=['A', 'B'])
@@ -815,19 +824,36 @@ def nth(self, n, dropna=None):
815824
5 NaN
816825
817826
"""
827+
if isinstance(n, int):
828+
nth_values = [n]
829+
elif isinstance(n, (set, list, tuple)):
830+
nth_values = list(set(n))
831+
if dropna is not None:
832+
raise ValueError("dropna option with a list of nth values is not supported")
833+
else:
834+
raise TypeError("n needs to be an int or a list/set/tuple of ints")
835+
836+
m = self.grouper._max_groupsize
837+
# filter out values that are outside [-m, m)
838+
pos_nth_values = [i for i in nth_values if i >= 0 and i < m]
839+
neg_nth_values = [i for i in nth_values if i < 0 and i >= -m]
818840

819841
self._set_selection_from_grouper()
820842
if not dropna: # good choice
821-
m = self.grouper._max_groupsize
822-
if n >= m or n < -m:
843+
if not pos_nth_values and not neg_nth_values:
844+
# no valid nth values
823845
return self._selected_obj.loc[[]]
846+
824847
rng = np.zeros(m, dtype=bool)
825-
if n >= 0:
826-
rng[n] = True
827-
is_nth = self._cumcount_array(rng)
828-
else:
829-
rng[- n - 1] = True
830-
is_nth = self._cumcount_array(rng, ascending=False)
848+
for i in pos_nth_values:
849+
rng[i] = True
850+
is_nth = self._cumcount_array(rng)
851+
852+
if neg_nth_values:
853+
rng = np.zeros(m, dtype=bool)
854+
for i in neg_nth_values:
855+
rng[- i - 1] = True
856+
is_nth |= self._cumcount_array(rng, ascending=False)
831857

832858
result = self._selected_obj[is_nth]
833859

pandas/tests/test_groupby.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,30 @@ def test_nth(self):
313313
expected = g.B.first()
314314
assert_series_equal(result,expected)
315315

316+
# test multiple nth values
317+
df = DataFrame([[1, np.nan], [1, 3], [1, 4], [5, 6], [5, 7]],
318+
columns=['A', 'B'])
319+
g = df.groupby('A')
320+
321+
assert_frame_equal(g.nth(0), df.iloc[[0, 3]].set_index('A'))
322+
assert_frame_equal(g.nth([0]), df.iloc[[0, 3]].set_index('A'))
323+
assert_frame_equal(g.nth([0, 1]), df.iloc[[0, 1, 3, 4]].set_index('A'))
324+
assert_frame_equal(g.nth([0, -1]), df.iloc[[0, 2, 3, 4]].set_index('A'))
325+
assert_frame_equal(g.nth([0, 1, 2]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
326+
assert_frame_equal(g.nth([0, 1, -1]), df.iloc[[0, 1, 2, 3, 4]].set_index('A'))
327+
assert_frame_equal(g.nth([2]), df.iloc[[2]].set_index('A'))
328+
assert_frame_equal(g.nth([3, 4]), df.loc[[],['B']])
329+
330+
business_dates = pd.date_range(start='4/1/2014', end='6/30/2014', freq='B')
331+
df = DataFrame(1, index=business_dates, columns=['a', 'b'])
332+
# get the first, fourth and last two business days for each month
333+
result = df.groupby((df.index.year, df.index.month)).nth([0, 3, -2, -1])
334+
expected_dates = pd.to_datetime(['2014/4/1', '2014/4/4', '2014/4/29', '2014/4/30',
335+
'2014/5/1', '2014/5/6', '2014/5/29', '2014/5/30',
336+
'2014/6/2', '2014/6/5', '2014/6/27', '2014/6/30'])
337+
expected = DataFrame(1, columns=['a', 'b'], index=expected_dates)
338+
assert_frame_equal(result, expected)
339+
316340
def test_grouper_index_types(self):
317341
# related GH5375
318342
# groupby misbehaving when using a Floatlike index

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy