Skip to content

Commit 6195a78

Browse files
authored
feat: allow disabling tldextract HTTP requests (supertokens#563)
- Workaround till tldextract#233 is implemented - Adds flag `FLAG_tldextract_disable_http` to control disabling HTTP requests - Adds `pyfakefs` to use in tests - Pins `tldextract` to current major version
1 parent d00c0a6 commit 6195a78

File tree

10 files changed

+97
-12
lines changed

10 files changed

+97
-12
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ __pycache__
88
releasePassword
99
apiPassword
1010
venv/
11-
env/
11+
./env/
1212
.env
1313
.DS_Store
1414
bin/

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88

99
## [unreleased]
10-
- Upgrades `pip` and `setuptools` in CI publish job
11-
- Also upgrades `poetry` and it's dependency - `clikit`
1210

1311
## [0.29.0] - 2025-03-03
12+
- Adds option to disable `tldextract` HTTP calls by setting `SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP=1`
13+
- Upgrades `pip` and `setuptools` in CI publish job
14+
- Also upgrades `poetry` and it's dependency - `clikit`
1415
- Migrates unit tests to use a containerized core
1516
- Updates `Makefile` to use a Docker `compose` setup step
1617
- Migrates unit tests from CircleCI to Github Actions

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ flask-cors==5.0.0
99
nest-asyncio==1.6.0
1010
pdoc3==0.11.0
1111
pre-commit==3.5.0
12+
pyfakefs==5.7.4
1213
pylint==3.2.7
1314
pyright==1.1.393
1415
python-dotenv==1.0.1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117
"PyJWT[crypto]>=2.5.0,<3.0.0",
118118
"httpx>=0.15.0,<1.0.0",
119119
"pycryptodome<3.21.0",
120-
"tldextract<5.1.3",
120+
"tldextract<6.0.0",
121121
"asgiref>=3.4.1,<4",
122122
"typing_extensions>=4.1.1,<5.0.0",
123123
"Deprecated<1.3.0",

supertokens_python/env/__init__.py

Whitespace-only changes.

supertokens_python/env/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from os import environ
2+
3+
from supertokens_python.env.utils import str_to_bool
4+
5+
6+
def FLAG_tldextract_disable_http():
7+
"""
8+
Disable HTTP calls from `tldextract`.
9+
"""
10+
val = environ.get("SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP", "0")
11+
12+
return str_to_bool(val)

supertokens_python/env/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def str_to_bool(val: str) -> bool:
2+
"""
3+
Convert ENV values to boolean
4+
"""
5+
return val.lower() in ("true", "t", "1")

supertokens_python/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
from urllib.parse import urlparse
3636

3737
from httpx import HTTPStatusError, Response
38-
from tldextract import extract # type: ignore
38+
from tldextract import TLDExtract
3939

40+
from supertokens_python.env.base import FLAG_tldextract_disable_http
4041
from supertokens_python.framework.django.framework import DjangoFramework
4142
from supertokens_python.framework.fastapi.framework import FastapiFramework
4243
from supertokens_python.framework.flask.framework import FlaskFramework
@@ -288,7 +289,16 @@ def get_top_level_domain_for_same_site_resolution(url: str) -> str:
288289
if hostname.startswith("localhost") or is_an_ip_address(hostname):
289290
return "localhost"
290291

291-
parsed_url: Any = extract(hostname, include_psl_private_domains=True)
292+
extract = TLDExtract(fallback_to_snapshot=True, include_psl_private_domains=True)
293+
# Explicitly disable HTTP calls, use snapshot bundled into library
294+
if FLAG_tldextract_disable_http():
295+
extract = TLDExtract(
296+
suffix_list_urls=(), # Ensures no HTTP calls
297+
fallback_to_snapshot=True,
298+
include_psl_private_domains=True,
299+
)
300+
301+
parsed_url: Any = extract(hostname)
292302
if parsed_url.domain == "": # type: ignore
293303
# We need to do this because of https://github.com/supertokens/supertokens-python/issues/394
294304
if hostname.endswith(".amazonaws.com") and parsed_url.suffix == hostname:

tests/test_utils.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1+
import os
12
import threading
3+
from contextlib import ExitStack
24
from typing import Any, Dict, List, Union
5+
from unittest.mock import patch
36

4-
import pytest
7+
from pytest import mark, param, raises
58
from supertokens_python.utils import (
69
RWMutex,
710
get_top_level_domain_for_same_site_resolution,
811
humanize_time,
912
is_version_gte,
1013
)
1114

12-
from tests.utils import is_subset
15+
from tests.utils import is_subset, outputs
1316

1417

15-
@pytest.mark.parametrize(
18+
@mark.parametrize(
1619
"version,min_minor_version,is_gte",
1720
[
1821
(
@@ -72,7 +75,7 @@ def test_util_is_version_gte(version: str, min_minor_version: str, is_gte: bool)
7275
HOUR = 60 * MINUTE
7376

7477

75-
@pytest.mark.parametrize(
78+
@mark.parametrize(
7679
"ms,out",
7780
[
7881
(1 * SECOND, "1 second"),
@@ -91,7 +94,7 @@ def test_humanize_time(ms: int, out: str):
9194
assert humanize_time(ms) == out
9295

9396

94-
@pytest.mark.parametrize(
97+
@mark.parametrize(
9598
"d1,d2,result",
9699
[
97100
({"a": {"b": [1, 2]}, "c": 1}, {"c": 1}, True),
@@ -176,7 +179,7 @@ def balance_is_valid():
176179
assert actual_balance == expected_balance, "Incorrect account balance"
177180

178181

179-
@pytest.mark.parametrize(
182+
@mark.parametrize(
180183
"url,res",
181184
[
182185
("http://localhost:3001", "localhost"),
@@ -196,3 +199,41 @@ def balance_is_valid():
196199
)
197200
def test_tld_for_same_site(url: str, res: str):
198201
assert get_top_level_domain_for_same_site_resolution(url) == res
202+
203+
204+
@mark.parametrize(
205+
["internet_disabled", "env_val", "expectation"],
206+
[
207+
param(True, "False", raises(RuntimeError), id="Internet disabled, flag unset"),
208+
param(True, "True", outputs("google.com"), id="Internet disabled, flag set"),
209+
param(False, "False", outputs("google.com"), id="Internet enabled, flag unset"),
210+
param(False, "True", outputs("google.com"), id="Internet enabled, flag set"),
211+
],
212+
)
213+
def test_tldextract_http_toggle(
214+
internet_disabled: bool,
215+
env_val: str,
216+
expectation: Any,
217+
# pyfakefs fixture, mocks the filesystem
218+
# Mocking `tldextract`'s cache path does not work in repeated tests
219+
fs: Any,
220+
):
221+
import socket
222+
223+
# Disable sockets, will raise errors on HTTP calls
224+
socket_patch = patch.object(socket, "socket", side_effect=RuntimeError)
225+
environ_patch = patch.dict(
226+
os.environ,
227+
{"SUPERTOKENS_TLDEXTRACT_DISABLE_HTTP": env_val},
228+
)
229+
230+
stack = ExitStack()
231+
stack.enter_context(environ_patch)
232+
if internet_disabled:
233+
stack.enter_context(socket_patch)
234+
235+
# if `expectation` is raises, checks for raise
236+
# if `outputs`, value used in `assert` statement
237+
with stack, expectation as expected_output:
238+
output = get_top_level_domain_for_same_site_resolution("https://google.com")
239+
assert output == expected_output

tests/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
# Import AsyncMock
1818
import sys
19+
from contextlib import contextmanager
1920
from datetime import datetime
2021
from functools import lru_cache
2122
from http.cookies import SimpleCookie
@@ -487,3 +488,17 @@ async def create_users(
487488
await manually_create_or_update_user(
488489
"public", user["provider"], user["userId"], user["email"], True, None
489490
)
491+
492+
493+
@contextmanager
494+
def outputs(val: Any):
495+
"""
496+
Outputs a value to assert.
497+
498+
Usage:
499+
@mark.parametrize(["input", "expectation"], [(1, outputs(1)), (0, raises(Exception))])
500+
def test_fn(input, expectation):
501+
with expectation as expected_output:
502+
assert 1 / input == expected_output
503+
"""
504+
yield val

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy