Skip to content

Commit 74b0660

Browse files
authored
Use sorted indices (#43)
* Use sorted indices * changes
1 parent 73e5e64 commit 74b0660

File tree

3 files changed

+50
-9
lines changed

3 files changed

+50
-9
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Change Logs
55
0.5.0
66
+++++
77

8+
* :pr:`43`: improves reproducibility of function train_test_apart_stratify
89
* :pr:`33`: removes pyquickhelper dependency
910
* :pr:`30`: fix compatibility with pandas 2.0
1011

_unittests/ut_df/test_connex_split_cat.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@ def test_cat_strat(self):
3737
lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError
3838
)
3939

40+
def test_cat_strat_sorted(self):
41+
df = pandas.DataFrame(
42+
[
43+
dict(a=1, b="e"),
44+
dict(a=2, b="e"),
45+
dict(a=4, b="f"),
46+
dict(a=8, b="f"),
47+
dict(a=32, b="f"),
48+
dict(a=16, b="f"),
49+
]
50+
)
51+
52+
train, test = train_test_apart_stratify(
53+
df, group="a", stratify="b", test_size=0.5, sorted_indices=True
54+
)
55+
self.assertEqual(train.shape[1], test.shape[1])
56+
self.assertEqual(train.shape[0] + test.shape[0], df.shape[0])
57+
c1 = Counter(train["b"])
58+
c2 = Counter(train["b"])
59+
self.assertEqual(c1, c2)
60+
61+
self.assertRaise(
62+
lambda: train_test_apart_stratify(
63+
df, group=None, stratify="b", test_size=0.5, sorted_indices=True
64+
),
65+
ValueError,
66+
)
67+
self.assertRaise(
68+
lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError
69+
)
70+
4071
def test_cat_strat_multi(self):
4172
df = pandas.DataFrame(
4273
[

pandas_streaming/df/connex_split.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import Counter
22
from logging import getLogger
3+
from typing import Optional, Tuple
34
import pandas
45
import numpy
56
from .dataframe_helpers import dataframe_shuffle
@@ -447,14 +448,15 @@ def double_merge(d):
447448

448449

449450
def train_test_apart_stratify(
450-
df,
451+
df: pandas.DataFrame,
451452
group,
452-
test_size=0.25,
453-
train_size=None,
454-
stratify=None,
455-
force=False,
456-
random_state=None,
457-
):
453+
test_size: Optional[float] = 0.25,
454+
train_size: Optional[float] = None,
455+
stratify: Optional[str] = None,
456+
force: bool = False,
457+
random_state: Optional[int] = None,
458+
sorted_indices: bool = False,
459+
) -> Tuple["StreamingDataFrame", "StreamingDataFrame"]: # noqa: F821
458460
"""
459461
This split is for a specific case where data is linked
460462
in one way. Let's assume we have two ids as we have
@@ -472,6 +474,8 @@ def train_test_apart_stratify(
472474
:param force: if True, tries to get at least one example on the test side
473475
for each value of the column *stratify*
474476
:param random_state: seed for random generators
477+
:param sorted_indices: sort index first,
478+
see issue `41 <https://github.com/sdpython/pandas-streaming/issues/41>`_
475479
:return: Two :class:`StreamingDataFrame
476480
<pandas_streaming.df.dataframe.StreamingDataFrame>`, one
477481
for train, one for test.
@@ -538,10 +542,15 @@ def train_test_apart_stratify(
538542

539543
split = {}
540544
for _, k in sorted_hist:
541-
not_assigned = [c for c in ids[k] if c not in split]
545+
indices = sorted(ids[k]) if sorted_indices else ids[k]
546+
not_assigned, assigned = [], []
547+
for c in indices:
548+
if c in split:
549+
assigned.append(c)
550+
else:
551+
not_assigned.append(c)
542552
if len(not_assigned) == 0:
543553
continue
544-
assigned = [c for c in ids[k] if c in split]
545554
nb_test = sum(split[c] for c in assigned)
546555
expected = min(len(ids[k]), int(test_size * len(ids[k]) + 0.5)) - nb_test
547556
if force and expected == 0 and nb_test == 0:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy