From ae450357266abf7dfad44e5f0cfc381569ab0105 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 8 Apr 2025 13:56:42 -0700 Subject: [PATCH 1/4] Foreach_map tutorial --- recipes_source/foreach_map.py | 174 +++++++++++++++++++++++++++++++ recipes_source/recipes_index.rst | 8 ++ 2 files changed, 182 insertions(+) create mode 100644 recipes_source/foreach_map.py diff --git a/recipes_source/foreach_map.py b/recipes_source/foreach_map.py new file mode 100644 index 00000000000..22071525d75 --- /dev/null +++ b/recipes_source/foreach_map.py @@ -0,0 +1,174 @@ +""" +(beta) Explicit horizontal fusion with foreach_map and torch.compile +============================================================ + +**Author:** `Michael Lazos `_ +""" + +######################################################### +# Horizontal fusion is a key optimization in ML compilers. In eager, +# this is typically expressed using the torch._foreach* ops which paralellizes +# operations across a list of tensors. However, supporting all possible permuatations +# of arguments is quite difficult (e.g. mixtures of scalars and lists). Foreach_map +# allows conversion of any pointwise op in torch to a horiztonally fused foreach +# variant. In this tutorial, we will demonstrate how implement the Adam optimizer +# with foreach_map and generate a fully fused kernel. +# +# +# .. note:: +# +# This tutorial requires PyTorch 2.6.0 or later. + +##################################################################### +# Model Setup +# ~~~~~~~~~~~~~~~~~~~~~ +# For this example, we'll use a simple sequence of linear layers. +# We instantiate an independent copy to compare the two optimizer implementations. +# + +# exit cleanly if we are on a device that doesn't support ``torch.compile`` +if torch.cuda.get_device_capability() < (7, 0): + print("Exiting because torch.compile is not supported on this device.") + import sys + sys.exit(0) + +import torch + +# Create simple model +model = torch.nn.Sequential( + *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)] +) +model_copy = torch.nn.Sequential( + *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)] +) +input = torch.rand(1024, device="cuda") + +# run forward pass +output = model(input) +output_copy = model_copy(input) + +# run backward to populate the grads for our optimizer below +output.sum().backward() +output_copy.sum().backward() + +##################################################################### +# Helper functions for foreach_map implementation +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In this section, we'll begin out implementation of the Adam optimizer. 
+# +from torch._higher_order_ops.foreach_map import foreach_map + +# Helper function to extract optimizer states from a torch.optim.Adam instance +def get_inputs(optim): + steps = [] + params = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + for group in optim.param_groups: + for p in group["params"]: + params.append(p) + grads.append(p.grad) + state = optim.state[p] + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + steps.append(state["step"]) + + return steps, params, exp_avgs, exp_avg_sqs + + +# Functions to update the different optimizer states +def update_exp_avg_sq(exp_avg_sq, grad, beta2): + return exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2) + +def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps): + bias_correction1 = 1 - torch.pow(beta1, step) + bias_correction2 = (1 - torch.pow(beta2, step)).sqrt() + step_size = (lr / bias_correction1).neg() + denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size) + return torch.add(param, torch.div(exp_avg, denom)) + +# Our full adam implementation +def foreach_map_adam( + steps, + params, + exp_avgs, + exp_avg_sqs, + weight_decay=0, + beta1=0.9, + beta2=0.999, + lr=1e-3, + eps=1e-8, +): + with torch.no_grad(): + grads = [param.grad for param in params] + # update step + updated_steps = foreach_map(lambda x: x + 1, steps) + torch._foreach_copy_(steps, updated_steps) + + if weight_decay != 0: + foreach_map(torch.add, (grads,), alpha=weight_decay) + + # Higher-order operators (HOPs) cannot have multiple outputs at the moment + # need to call foreach_map once for each output + exp_avgs_updated = foreach_map(torch.lerp, exp_avgs, grads, 1 - beta1) + exp_avgs_sq_updated = foreach_map(update_exp_avg_sq, exp_avg_sqs, grads, beta2) + params_updated = foreach_map( + update_param, + params, + steps, + exp_avgs_updated, + exp_avgs_sq_updated, + beta1, + beta2, + lr, + eps, + ) + # Higher-order operators (HOPs) don't support input mutation today + # so manually update the states in-place + torch._foreach_copy_(exp_avgs, exp_avgs_updated) + torch._foreach_copy_(exp_avg_sqs, exp_avgs_sq_updated) + torch._foreach_copy_(params, params_updated) + return + +##################################################################### +# Setting up and running the compiled kernel +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In this section, we'll run our Adam optimizer +# and compare the results +# +# .. note:: +# +# ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher. +opt_eager = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01)) +opt_eager_copy = torch.optim.Adam(model_copy.parameters(), lr=torch.tensor(0.01)) + +# warm up the optimizer state dict +opt_eager.step() +opt_eager_copy.step() + +inputs = get_inputs(opt_eager_copy) +compiled_adam = torch.compile(foreach_map_adam) + +# optionally view the output code +torch._logging.set_logs(output_code=True) + +# Warmup runs to compile the function +for _ in range(5): + opt_eager.step() + compiled_adam(*inputs) + +for eager_p, compile_p in zip(opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]): + torch.allclose(eager_p, compile_p) + +###################################################################### +# Conclusion +# ~~~~~~~~~~ +# In this tutorial, we implemented a custom fully fused Adam optimizer using foreach_map. 
+#
+# See also:
+#
+# * `Compiled optimizer tutorial `__ - an introduction to the compiled optimizer.
+# * `Compiling the optimizer with PT2 `__ - deeper technical details on the compiled optimizer.
diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst
index f136c4b9c67..c52e92f881d 100644
--- a/recipes_source/recipes_index.rst
+++ b/recipes_source/recipes_index.rst
@@ -335,6 +335,14 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
   :link: ../recipes/compiling_optimizer_lr_scheduler.html
   :tags: Model-Optimization

+.. (beta) Explicit horizontal fusion with foreach_map and torch.compile
+.. customcarditem::
+   :header: (beta) Explicit horizontal fusion with foreach_map and torch.compile
+   :card_description: Horizontally fuse pointwise ops with torch.compile
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/foreach_map.html
+   :tags: Model-Optimization
+
.. Using User-Defined Triton Kernels with ``torch.compile``
.. customcarditem::

From a96d0c477b3ef701e992932142970ed5a2719180 Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Wed, 9 Apr 2025 13:01:04 -0700
Subject: [PATCH 2/4] Update recipes_source/foreach_map.py

Co-authored-by: William Wen
---
 recipes_source/foreach_map.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/recipes_source/foreach_map.py b/recipes_source/foreach_map.py
index 22071525d75..0c32a567eba 100644
--- a/recipes_source/foreach_map.py
+++ b/recipes_source/foreach_map.py
@@ -12,7 +12,7 @@
# of arguments is quite difficult (e.g. mixtures of scalars and lists). Foreach_map
# allows conversion of any pointwise op in torch to a horiztonally fused foreach
# variant. In this tutorial, we will demonstrate how implement the Adam optimizer
-# with foreach_map and generate a fully fused kernel.
+# with ``foreach_map`` to generate a fully fused kernel.
#
#
# .. note::

From da21908bb84c57e83e806758f64f85d7d09eed2c Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Wed, 9 Apr 2025 13:05:10 -0700
Subject: [PATCH 3/4] Updates to tutorial

---
 recipes_source/foreach_map.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/recipes_source/foreach_map.py b/recipes_source/foreach_map.py
index 0c32a567eba..cb72d02bc4f 100644
--- a/recipes_source/foreach_map.py
+++ b/recipes_source/foreach_map.py
@@ -7,17 +7,17 @@
#########################################################
# Horizontal fusion is a key optimization in ML compilers. In eager,
-# this is typically expressed using the torch._foreach* ops which paralellizes
-# operations across a list of tensors. However, supporting all possible permuatations
+# this is typically expressed using the torch._foreach* ops which parallelizes
+# operations across a list of tensors. However, supporting all possible permutations
# of arguments is quite difficult (e.g. mixtures of scalars and lists). Foreach_map
-# allows conversion of any pointwise op in torch to a horiztonally fused foreach
-# variant. In this tutorial, we will demonstrate how implement the Adam optimizer
+# allows conversion of any pointwise op in ``torch`` to a horizontally fused foreach
+# variant. In this tutorial, we will demonstrate how to implement the Adam optimizer
# with ``foreach_map`` to generate a fully fused kernel.
#
#
# .. note::
#
-# This tutorial requires PyTorch 2.6.0 or later.
+# This tutorial requires PyTorch 2.7.0 or later.

#####################################################################
# Model Setup
# ~~~~~~~~~~~~~~~~~~~~~
@@ -25,6 +25,7 @@
# For this example, we'll use a simple sequence of linear layers.
# We instantiate an independent copy to compare the two optimizer implementations.
#
+import torch

# exit cleanly if we are on a device that doesn't support ``torch.compile``
if torch.cuda.get_device_capability() < (7, 0):
@@ -32,8 +33,6 @@
    import sys
    sys.exit(0)

-import torch
-
# Create simple model
model = torch.nn.Sequential(
    *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
@@ -55,7 +54,7 @@
# Helper functions for foreach_map implementation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
-# In this section, we'll begin out implementation of the Adam optimizer.
+# In this section, we'll begin our implementation of the Adam optimizer.
#
from torch._higher_order_ops.foreach_map import foreach_map
@@ -89,7 +88,7 @@ def update_param(param, step, exp_avg, exp_avg_sq, beta1, beta2, lr, eps):
    denom = (exp_avg_sq.sqrt() / (bias_correction2 * step_size)).add(eps / step_size)
    return torch.add(param, torch.div(exp_avg, denom))

-# Our full adam implementation
+# Our full Adam implementation
def foreach_map_adam(
    steps,
    params,
@@ -166,7 +165,11 @@ def foreach_map_adam(
######################################################################
# Conclusion
# ~~~~~~~~~~
-# In this tutorial, we implemented a custom fully fused Adam optimizer using foreach_map.
+# In this tutorial, we successfully implemented a custom fully-fused Adam optimizer using foreach_map.
+# By leveraging the power of foreach_map and torch.compile, we were able to create an optimized version of the Adam
+# optimizer that can be used in various machine learning applications. This tutorial provides a comprehensive guide
+# on how to use foreach_map and torch.compile to optimize machine learning models, and serves as a
+# valuable resource for developers looking to improve the performance of their models with horizontal fusion.
#
# See also:
#

From ada986dc3837fceb278eeacbcf4188f02fac8095 Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Wed, 9 Apr 2025 13:12:13 -0700
Subject: [PATCH 4/4] More updates

---
 recipes_source/foreach_map.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/recipes_source/foreach_map.py b/recipes_source/foreach_map.py
index cb72d02bc4f..b8bf0aa2836 100644
--- a/recipes_source/foreach_map.py
+++ b/recipes_source/foreach_map.py
@@ -162,6 +162,27 @@ def foreach_map_adam(
for eager_p, compile_p in zip(opt_eager.param_groups[0]["params"], opt_eager_copy.param_groups[0]["params"]):
    torch.allclose(eager_p, compile_p)

+# Benchmark performance
+
+# Let's define a helpful benchmarking function:
+import torch.utils.benchmark as benchmark
+
+def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+    )
+    return t0.blocked_autorange().mean * 1e6
+
+eager_runtime = benchmark_torch_function_in_microseconds(opt_eager.step)
+compiled_runtime = benchmark_torch_function_in_microseconds(lambda: compiled_adam(*inputs))
+
+assert eager_runtime > compiled_runtime
+
+print(f"eager runtime: {eager_runtime}us")
+print(f"compiled runtime: {compiled_runtime}us")
+
+
######################################################################
# Conclusion
# ~~~~~~~~~~
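For quick orientation, below is a minimal standalone sketch (not part of the patch series above) of the core pattern the new tutorial teaches: lifting a pointwise op into a horizontally fused variant with ``foreach_map`` and compiling it with ``torch.compile``. It assumes a CUDA device with compute capability 7.0 or higher and a PyTorch build that exposes ``torch._higher_order_ops.foreach_map`` (2.7+); the ``fused_add`` helper and the example tensor lists are illustrative names, not taken from the tutorial.

import torch
from torch._higher_order_ops.foreach_map import foreach_map

# Two lists of tensors; a plain Python loop of adds would launch one kernel per pair
xs = [torch.randn(1024, device="cuda") for _ in range(8)]
ys = [torch.randn(1024, device="cuda") for _ in range(8)]

def fused_add(xs, ys):
    # foreach_map applies the pointwise torch.add across both lists, so that
    # torch.compile can emit a horizontally fused kernel for the whole list
    return foreach_map(torch.add, xs, ys)

compiled_fused_add = torch.compile(fused_add)
out = compiled_fused_add(xs, ys)  # list of 8 result tensors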
