TypeError: cannot pickle '_thread.lock' object in TensorFlow 2.4 #46556

Closed
rpasricha opened this issue Jan 20, 2021 · 11 comments
Labels: comp:dist-strat (Distribution Strategy related issues), TF 2.4 (for issues related to TF 2.4), type:bug (Bug)

Comments

@rpasricha

Please make sure that this is a bug. As per our
GitHub Policy,
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian 10
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): Binary
  • TensorFlow version (use command below): v2.4.0-0-g582c8d236cb 2.4.0
  • Python version: 3.7.9
  • Bazel version (if compiling from source): n/a
  • GCC/Compiler version (if compiling from source): n/a
  • CUDA/cuDNN version: n/a
  • GPU model and memory: n/a

You can collect some of this information using our environment capture
script
You can also obtain the TensorFlow version with:

  1. TF 1.0: python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
  2. TF 2.0: python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the current behavior
Running a simple training process with MultiWorkerMirroredStrategy fails with TypeError: can't pickle _thread.lock objects.

Describe the expected behavior
The training should proceed without errors.

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.

The example needs to run in a distributed environment to reproduce the issue, so save the script in a file and run it in 3 different terminals.

TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "chief", "index": 0}}' python script.py 
TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "worker", "index": 0}}' python script.py
TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "worker", "index": 1}}' python script.py

import tensorflow as tf
import tensorflow_datasets as tfds
 
buffer_size = 10000
batch_size = 64
learning_rate = 1e-4
 
def input_fn(mode, input_context=None):
  tfds.disable_progress_bar()
  datasets, _ = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_dataset = ( 
      datasets['train']
      if mode == tf.estimator.ModeKeys.TRAIN else datasets['test'])
 
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255 
    return image, label
 
  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(buffer_size).batch(batch_size)
 
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)
 
  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
 
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=learning_rate)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels,
                                                                  logits)
  loss = tf.reduce_sum(loss) * (1. / batch_size)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)
 
  logging_hook = tf.estimator.LoggingTensorHook({'loss': loss}, every_n_iter=10)
 
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      training_hooks=[logging_hook],
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))
 
 
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
 
config = tf.estimator.RunConfig(train_distribute=strategy)
 
classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
 
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))
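
The three TF_CONFIG values in the launch commands above differ only in their "task" entry. As a minimal sketch (the helper names make_tf_config and all_tf_configs are hypothetical and not part of TensorFlow), they could be generated from one cluster description rather than typed by hand:

```python
import json

# Cluster layout copied from the launch commands above.
CLUSTER = {
    "chief": ["localhost:2222"],
    "worker": ["localhost:2223", "localhost:2224"],
}


def make_tf_config(task_type, index):
    """Return the TF_CONFIG JSON string for one task in CLUSTER (hypothetical helper)."""
    if index >= len(CLUSTER[task_type]):
        raise ValueError(f"no {task_type} with index {index}")
    return json.dumps(
        {"cluster": CLUSTER, "task": {"type": task_type, "index": index}}
    )


def all_tf_configs():
    """Return one (task_type, index, TF_CONFIG) tuple per process to launch."""
    return [
        (task_type, index, make_tf_config(task_type, index))
        for task_type, addresses in CLUSTER.items()
        for index in range(len(addresses))
    ]


if __name__ == "__main__":
    # Print one launch command per terminal.
    for task_type, index, tf_config in all_tf_configs():
        print(f"TF_CONFIG='{tf_config}' python script.py")
```

Each printed line is meant to be run in its own terminal, as in the original instructions.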

Other info / logs Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

Full logs:

TF_CONFIG='{"cluster": {"chief": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}, "task": {"type": "worker", "index": 1}}' python script.py
WARNING:tensorflow:From script.py:68: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:                                                                
use distribute.MultiWorkerMirroredStrategy instead                                                       
2021-01-20 18:24:44.477611: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-01-20 18:24:44.479538: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
2021-01-20 18:24:44.491607: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job chief -> {0 -> localhost:2222}
2021-01-20 18:24:44.491654: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job worker -> {0 -> localhost:2223, 1 -> localhost:2224}
2021-01-20 18:24:44.492211: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:411] Started server with target: grpc://localhost:2224
Traceback (most recent call last):                                                                                           
  File "script.py", line 73, in <module>                                                  
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)                                
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 183, in __init__
    config, model_dir)                                                                          
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1832, in maybe_overwrite_model_dir_and_session_config
    config = run_config.RunConfig.replace(config, session_config=session_config)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/run_config.py", line 923, in replace
    copy.deepcopy(self),                                                             
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy                    
    y = _reconstruct(x, memo, *rv)                                                                 
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct         
    state = deepcopy(state, memo)                                                              
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy            
    y = copier(x, memo)                                                                            
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict       
    y[deepcopy(key, memo)] = deepcopy(value, memo)          
  File "/opt/conda/lib/python3.7/copy.py", line 161, in deepcopy                      
    y = copier(memo)                                                                                                       
  File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1542, in __deepcopy__
    setattr(result, k, copy.deepcopy(v, memo))                                                    
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy                                             
    y = _reconstruct(x, memo, *rv)                                                
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct                                             
    state = deepcopy(state, memo)                                                   
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy                             
    y = copier(x, memo)                                                                    
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict              
    y[deepcopy(key, memo)] = deepcopy(value, memo)                                                                   
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy                      
    y = _reconstruct(x, memo, *rv)                                                             
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct                                             
    state = deepcopy(state, memo)                                                   
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy            
    y = copier(x, memo)                                                                      
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict                               
    y[deepcopy(key, memo)] = deepcopy(value, memo)                            
  File "/opt/conda/lib/python3.7/copy.py", line 180, in deepcopy                       
    y = _reconstruct(x, memo, *rv)                                                                                             
  File "/opt/conda/lib/python3.7/copy.py", line 281, in _reconstruct                       
    state = deepcopy(state, memo)                                                                          
  File "/opt/conda/lib/python3.7/copy.py", line 150, in deepcopy                 
    y = copier(x, memo)                                                                   
  File "/opt/conda/lib/python3.7/copy.py", line 241, in _deepcopy_dict                     
    y[deepcopy(key, memo)] = deepcopy(value, memo)                                                                     
  File "/opt/conda/lib/python3.7/copy.py", line 169, in deepcopy                       
    rv = reductor(4)                                                                                                 
TypeError: can't pickle _thread.lock objects  
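
The bottom frames of the traceback show copy.deepcopy reaching an object that holds a lock. The same TypeError can be reproduced without TensorFlow. The sketch below uses a hypothetical HoldsLock class as a stand-in and makes no claim about which TensorFlow attribute actually holds the lock. The wording of the message depends on the Python version, which is presumably why the issue title and the log above phrase it differently.

```python
import copy
import threading


class HoldsLock:
    """Hypothetical stand-in for any object that keeps a lock as an attribute."""

    def __init__(self):
        self.name = "example"
        self._lock = threading.Lock()


def try_deepcopy(obj):
    """Return None on success, or the TypeError message if the copy fails."""
    try:
        copy.deepcopy(obj)
    except TypeError as err:
        return str(err)
    return None


# deepcopy walks the instance __dict__ and fails when it reaches the lock.
message = try_deepcopy(HoldsLock())
print(message)
```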
@amahendrakar
Contributor

amahendrakar commented Jan 20, 2021

Was able to reproduce the issue with TF v2.4 and TF-nightly. Please check the attached screenshot for reference.

[Screenshot 2021-01-20 at 2 29 33 PM]

With TF v2.3, by contrast, the error is: W tensorflow/core/common_runtime/eager/context.cc:566] Unable to destroy server_ object, so releasing instead. Servers don't support clean shutdown.

Similar to issue #45918

Thanks!

amahendrakar added the comp:dist-strat (Distribution Strategy related issues) and TF 2.4 (for issues related to TF 2.4) labels on Jan 20, 2021
@nikitamaia
Member

Hi @rpasricha, to narrow down the issue here I just tried running the sample provided in the Multi-worker training with Estimator tutorial. I seem to be getting the same can't pickle _thread.lock objects error. Can you test with that code and let me know if you see the same?

@haitong

haitong commented Jan 28, 2021

I am having the same problem in TF2.4.0 and TF2.4.1 with the same stacktrace. I tried both python3.6/python3.7 with tf.distribute.experimental.MultiWorkerMirroredStrategy() and tf.distribute.MultiWorkerMirroredStrategy(). Any updates?

@Deprecated-GeforceTesla

Also facing the issue. Is there a workaround for this issue?

@haitong

haitong commented Jan 29, 2021

> Also facing the issue. Is there a workaround for this issue?

The problem is the threading introduced in TF 2.4: https://github.com/tensorflow/tensorflow/blob/r2.4/tensorflow/python/distribute/collective_all_reduce_strategy.py#L22

I found a workaround: explicitly override __deepcopy__ for both RunConfig and MultiWorkerMirroredStrategy:

from copy import deepcopy

import tensorflow as tf


class MyRunConfig(tf.estimator.RunConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if '_distribute' in k:
                # Share the distribution strategy by reference instead of copying it.
                setattr(result, k, v)
            else:
                setattr(result, k, deepcopy(v, memo))
        return result


class MyDistributeStrategy(tf.distribute.MultiWorkerMirroredStrategy):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if '_extend' in k:
                # Share attributes whose name contains '_extend' by reference.
                setattr(result, k, v)
            else:
                setattr(result, k, deepcopy(v, memo))
        return result

In your code, use the two classes above in place of tf.estimator.RunConfig and tf.distribute.MultiWorkerMirroredStrategy, as follows:

strategy = MyDistributeStrategy()
config = MyRunConfig(train_distribute=strategy)
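
The technique behind this workaround can be shown in isolation. The following is a minimal sketch using only the standard library. Resource and Config are hypothetical stand-ins, not TensorFlow classes. It shows a __deepcopy__ override that shares one non-copyable attribute by reference and deep-copies everything else.

```python
import copy
import threading


class Resource:
    """Hypothetical object that cannot be deep-copied because it owns a lock."""

    def __init__(self):
        self._lock = threading.Lock()


class Config:
    """Hypothetical config that shares `_resource` by reference when copied."""

    def __init__(self, resource, options):
        self._resource = resource
        self.options = options

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for key, value in self.__dict__.items():
            if key == "_resource":
                setattr(result, key, value)  # shared, not copied
            else:
                setattr(result, key, copy.deepcopy(value, memo))
        return result


original = Config(Resource(), {"steps": 100})
clone = copy.deepcopy(original)
print(clone._resource is original._resource)  # True: the resource is shared
print(clone.options is original.options)  # False: the options dict was copied
```

The trade-off is that the copy and the original now share the same underlying object, so this only makes sense when that object is meant to be shared.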

@crccw
Member

crccw commented Jan 29, 2021

Could you try disabling eager execution via tf.compat.v1.disable_eager_execution()?

It needs to be called at the beginning of the program.

@haitong

haitong commented Jan 30, 2021

Found another workaround.

Add the following code to explicitly disable the health check:

from tensorflow.python.distribute.collective_all_reduce_strategy import CollectiveAllReduceExtended
CollectiveAllReduceExtended._enable_check_health = False
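
One plausible reading of why this helps (an assumption, not confirmed in this thread) is that the health check creates threading objects on the strategy, and that setting the class-level flag before the strategy is constructed prevents them from being created. The sketch below illustrates that mechanism with a hypothetical Extended class. It is not TensorFlow's implementation.

```python
import copy
import threading


class Extended:
    """Hypothetical stand-in; not TensorFlow's CollectiveAllReduceExtended."""

    _enable_check_health = True  # class-level switch, read at construction time

    def __init__(self):
        self._stop = None
        if self._enable_check_health:
            # threading.Event owns a lock internally, so it cannot be deep-copied.
            self._stop = threading.Event()


def can_deepcopy(obj):
    """Return True if copy.deepcopy(obj) succeeds, False on TypeError."""
    try:
        copy.deepcopy(obj)
    except TypeError:
        return False
    return True


before = can_deepcopy(Extended())
Extended._enable_check_health = False  # patch the class before creating instances
after = can_deepcopy(Extended())
print(before, after)  # False True
```

Under this reading, the patch has to run before the strategy object is created, because instances built earlier would already hold the threading objects.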

@rpasricha
Author

I was able to get around the issue by disabling eager execution, thanks.

@nikitamaia The tutorial only runs the code on a single node; the issue only arises when doing distributed training with Estimator and MultiWorkerMirroredStrategy.

@nikitamaia
Member

@rpasricha ah yes, I meant to say to run the code from the tutorial but in your multi node environment (not in colab). Regardless, seems we have a workaround for now.

@nikitamaia
Member

Closing this issue since it is a duplicate of #45918. For further updates please refer to the other thread so we can track this in one place.


TensorFlow-Docs-Copybara pushed a commit to tensorflow/docs that referenced this issue Mar 24, 2021