RNN-T Training Iteration
[Timeline view: per-iteration stages Load Data, Forward, Calc Loss, Zero Grad, Backward, and Optimizer, shown per resource]
cpu2: 1.67 s
gpu3: 1.54 s
Details:
RNN-T Training Iteration
Runtime: 1.67 s
Start time: 01:39:50.604527
Operations: 216,681
Trace size: 81.1 MB
Insights: CPU-Bound
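The exported trace uses the Chrome trace JSON format, so the headline numbers above can be sanity-checked outside the UI. A minimal sketch, assuming the trace_file path from the metadata below:

import json

# Path taken from the trace_file metadata entry below (adjust as needed).
TRACE_PATH = '/results/RNN/RNN.1.pt.trace.json'

with open(TRACE_PATH, 'r', encoding='utf-8') as f:
  trace = json.load(f)

# PyTorch profiler traces are Chrome-trace JSON: a top-level 'traceEvents'
# list holds one entry per recorded op/kernel event.
events = trace['traceEvents']
print(f'Operations: {len(events):,}')  # Expected to match the 216,681 above.

# Durations are in microseconds; complete events ('X') carry a 'dur' field.
total_us = sum(e.get('dur', 0) for e in events if e.get('ph') == 'X')
print(f'Sum of event durations: {total_us / 1e6:.2f} s')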
Code:
/home/ubuntu/algorithmic-efficiency/submission_runner.py:236
1r"""Run a submission on a single workload.
2
3Example command:
4
5python3 submission_runner.py \
6 --workload=mnist \
7 --framework=jax \
8 --submission_path=reference_submissions/mnist/mnist_jax/submission.py \
9 --tuning_ruleset=external \
10 --tuning_search_space=reference_submissions/mnist/tuning_search_space.json \
11 --num_tuning_trials=3
12"""
13import importlib
14import inspect
15import json
16import os
17import struct
18import time
19from typing import Optional, Tuple
20
21from absl import app
22from absl import flags
23from absl import logging
24import tensorflow as tf
25import torch
26import torch.distributed as dist
27
28import datetime
29
30from algorithmic_efficiency import halton
31from algorithmic_efficiency import random_utils as prng
32from algorithmic_efficiency import spec
33
34
35# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
36# it unavailable to JAX.
37tf.config.experimental.set_visible_devices([], 'GPU')
38
# TODO(znado): make a nicer registry of workloads that we can look up in.
BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'

# '_pytorch' or '_jax' is appended to workload_path automatically.
WORKLOADS = {
    'cifar': {
        'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload'
    },
    'criteo1tb': {
        'workload_path': 'criteo1tb/criteo1tb',
        'workload_class_name': 'Criteo1TbDlrmSmallWorkload'
    },
    'fastmri': {
        'workload_path': 'fastmri/fastmri',
        'workload_class_name': 'FastMRIWorkload'
    },
    'imagenet_resnet': {
        'workload_path': 'imagenet_resnet/imagenet',
        'workload_class_name': 'ImagenetResNetWorkload'
    },
    'imagenet_vit': {
        'workload_path': 'imagenet_vit/imagenet',
        'workload_class_name': 'ImagenetVitWorkload'
    },
    'librispeech': {
        'workload_path': 'librispeech/librispeech',
        'workload_class_name': 'LibriSpeechWorkload'
    },
    'mnist': {
        'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload'
    },
    'ogbg': {
        'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'
    },
    'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'},
}

flags.DEFINE_string(
    'submission_path',
    'reference_submissions/mnist/mnist_jax/submission.py',
    'The relative path of the Python file containing submission functions. '
    'NOTE: the submission dir must have an __init__.py file!')
flags.DEFINE_string(
    'workload',
    'mnist',
    help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}'
)
flags.DEFINE_enum(
    'tuning_ruleset',
    'external',
    enum_values=['external', 'self'],
    help='Which tuning ruleset to use.')
flags.DEFINE_string(
    'tuning_search_space',
    'reference_submissions/mnist/tuning_search_space.json',
    'The path to the JSON file describing the external tuning search space.')
flags.DEFINE_integer('num_tuning_trials',
                     20,
                     'The number of external hyperparameter trials to run.')
flags.DEFINE_string('data_dir', '~/tensorflow_datasets/', 'Dataset location.')
flags.DEFINE_enum(
    'framework',
    None,
    enum_values=['jax', 'pytorch'],
    help='Whether to use Jax or Pytorch for the submission. Controls, among '
    'other things, whether the Jax or Numpy RNG library is used for RNG.')

FLAGS = flags.FLAGS


def convert_filepath_to_module(path: str):
  base, extension = os.path.splitext(path)

  if extension != '.py':
    raise ValueError(f'Path: {path} must be a python file (*.py)')

  return base.replace('/', '.')


def import_workload(workload_path: str,
                    workload_class_name: str,
                    return_class=False) -> spec.Workload:
  """Import and add the workload to the registry.

  This importlib loading is nice to have because it allows runners to avoid
  installing the dependencies of all the supported frameworks. For example, if
  a submitter only wants to write Jax code, the try/except below will catch
  the import errors caused if they do not have the PyTorch dependencies
  installed on their system.

  Args:
    workload_path: the path to the `workload.py` file to load.
    workload_class_name: the name of the Workload class that implements the
      `Workload` abstract class in `spec.py`.
    return_class: if true, then the workload class is returned instead of the
      instantiated object. Useful for testing when methods need to be
      overridden.
  """

  # Remove the trailing '.py' and convert the filepath to a Python module.
  workload_path = convert_filepath_to_module(workload_path)

  # Import the workload module.
  workload_module = importlib.import_module(workload_path)
  # Get everything defined in the workload module (including our class).
  workload_module_members = inspect.getmembers(workload_module)
  workload_class = None
  for name, value in workload_module_members:
    if name == workload_class_name:
      workload_class = value
      break
  if workload_class is None:
    raise ValueError(
        f'Could not find member {workload_class_name} in {workload_path}. '
        'Make sure the Workload class is spelled correctly and defined in '
        'the top scope of the module.')
  if return_class:
    return workload_class
  return workload_class()

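# Example (hypothetical invocation, mirroring how main() builds the path):
#   workload = import_workload(
#       'algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py',
#       'MnistWorkload')
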
# Example reference implementation showing how to use the above functions
# together.
def train_once(workload: spec.Workload,
               global_batch_size: int,
               data_dir: str,
               init_optimizer_state: spec.InitOptimizerFn,
               update_params: spec.UpdateParamsFn,
               data_selection: spec.DataSelectionFn,
               hyperparameters: Optional[spec.Hyperparameters],
               rng: spec.RandomState) -> Tuple[spec.Timing, spec.Steps]:
  data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

  # Workload setup.
  logging.info('Initializing dataset.')
  input_queue = workload.build_input_queue(
      data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size)
  logging.info('Initializing model.')
  model_params, model_state = workload.init_model_fn(model_init_rng)
  logging.info('Initializing optimizer.')
  optimizer_state = init_optimizer_state(workload,
                                         model_params,
                                         model_state,
                                         hyperparameters,
                                         opt_init_rng)

  # Bookkeeping.
  goal_reached = False
  is_time_remaining = True
  last_eval_time = 0
  accumulated_submission_time = 0
  eval_results = []
  global_step = 0
  training_complete = False
  global_start_time = time.time()

  # Hotline profiling setup. The import is local so that the base runner does
  # not require hotline to be installed.
  import hotline

  last_time = datetime.datetime.now()
  print(last_time)
  model_params = torch.nn.DataParallel(model_params)

  # Profiler schedule: short for quick debug runs, longer otherwise.
  quick_run = os.environ.get('HOTLINE_QUICK_RUN')
  if quick_run:
    wait = 2
    warmup = 2
    active = 1
  else:
    wait = 20
    warmup = 19
    active = 1
  max_steps = wait + warmup + active

  metadata = {
      'model': 'RNN-T',
      'dataset': 'LibriSpeech',
      'batch_size': global_batch_size,
      'optimizer': 'Adam',
      'runtime': [],
  }

  torch_profiler = torch.profiler.profile(
      activities=[
          torch.profiler.ProfilerActivity.CPU,
          torch.profiler.ProfilerActivity.CUDA
      ],
      schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
      on_trace_ready=hotline.analyze(
          model_params,
          input_queue,
          run_name='RNN',
          test_accuracy=True,
          output_dir='/home/dans/cpath',
          metadata=metadata,
      ),
      record_shapes=False,
      profile_memory=False,
      with_stack=False)

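  # Schedule arithmetic: with wait=20, warmup=19, active=1 the profiler idles
  # for 20 steps, warms up for 19, then records exactly one step, so the run
  # exits after max_steps = 20 + 19 + 1 = 41 iterations (see the global_step
  # check in the loop below).
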
  logging.info('Starting training loop.')
  while is_time_remaining and not goal_reached and not training_complete:
    step_rng = prng.fold_in(rng, global_step)
    data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
    start_time = time.time()
    with hotline.annotate('Load Data'):
      batch = data_selection(workload,
                             input_queue,
                             optimizer_state,
                             model_params,
                             hyperparameters,
                             global_step,
                             data_select_rng)
    try:
      optimizer_state, model_params, model_state = update_params(
          workload=workload,
          current_param_container=model_params,
          current_params_types=workload.model_params_types,
          model_state=model_state,
          hyperparameters=hyperparameters,
          batch=batch,
          loss_type=workload.loss_type,
          optimizer_state=optimizer_state,
          eval_results=eval_results,
          global_step=global_step,
          rng=update_rng)

      # Hotline profiling: record the wall-clock time of this step, advance
      # the profiler schedule, and exit once the active step has been traced.
      this_time = datetime.datetime.now()
      tdelta = this_time - last_time
      logging.info(f'tdelta: {tdelta}')
      metadata['runtime'].append(tdelta)
      last_time = this_time
      logging.info(f'global_step: {global_step}\n')
      torch_profiler.step()
      hotline.annotate.step()
      if global_step >= max_steps:
        sys.exit(0)
      global_step += 1
      continue

    except spec.TrainingCompleteError:
      training_complete = True
    global_step += 1
    current_time = time.time()
    accumulated_submission_time += current_time - start_time
    is_time_remaining = (
        accumulated_submission_time < workload.max_allowed_runtime_sec)
    # Check if submission is eligible for an untimed eval (disabled while
    # profiling):
    # if (current_time - last_eval_time >= workload.eval_period_time_sec or
    #     training_complete):
    #   latest_eval_result = workload.eval_model(global_batch_size,
    #                                            model_params,
    #                                            model_state,
    #                                            eval_rng,
    #                                            data_dir)
    #   logging.info('%.2fs \t%d \t%s',
    #                current_time - global_start_time,
    #                global_step,
    #                latest_eval_result)
    #   last_eval_time = current_time
    #   eval_results.append((global_step, latest_eval_result))
    #   goal_reached = workload.has_reached_goal(latest_eval_result)
  metrics = {'eval_results': eval_results, 'global_step': global_step}
  return accumulated_submission_time, metrics

def score_submission_on_workload(workload: spec.Workload,
                                 workload_name: str,
                                 submission_path: str,
                                 data_dir: str,
                                 tuning_ruleset: str,
                                 tuning_search_space: Optional[str] = None,
                                 num_tuning_trials: Optional[int] = None):
  # Remove the trailing '.py' and convert the filepath to a Python module.
  submission_module_path = convert_filepath_to_module(submission_path)
  submission_module = importlib.import_module(submission_module_path)

  init_optimizer_state = submission_module.init_optimizer_state
  update_params = submission_module.update_params
  data_selection = submission_module.data_selection
  global_batch_size = submission_module.get_batch_size(workload_name)

  if tuning_ruleset == 'external':
    # If the submission runner is responsible for hyperparameter tuning, load
    # in the search space and generate a list of randomly selected
    # hyperparameter settings from it.
    if tuning_search_space is None:
      raise ValueError(
          'Must provide a tuning search space JSON file when using external '
          'tuning.')
    with open(tuning_search_space, 'r', encoding='UTF-8') as search_space_file:
      tuning_search_space = halton.generate_search(
          json.load(search_space_file), num_tuning_trials)
    all_timings = []
    all_metrics = []
    for hi, hyperparameters in enumerate(tuning_search_space):
      # Generate a new seed from hardware sources of randomness for each trial.
      rng_seed = struct.unpack('I', os.urandom(4))[0]
      rng = prng.PRNGKey(rng_seed)
      # Because we initialize the PRNGKey with only a single 32 bit int, in the
      # Jax implementation this means that rng[0] is all zeros, which could
      # lead to unintentionally reusing the same seed if only rng[0] were ever
      # used. By splitting the rng into 2, we mix the lower and upper 32 bit
      # ints, ensuring we can safely use either rng[0] or rng[1] as a random
      # number.
      rng, _ = prng.split(rng, 2)
      logging.info('--- Tuning run %d/%d ---', hi + 1, num_tuning_trials)
      timing, metrics = train_once(workload, global_batch_size, data_dir,
                                   init_optimizer_state, update_params,
                                   data_selection, hyperparameters, rng)
      all_timings.append(timing)
      all_metrics.append(metrics)
    score = min(all_timings)
    for ti in range(num_tuning_trials):
      logging.info('Tuning trial %d/%d', ti + 1, num_tuning_trials)
      logging.info('Hyperparameters: %s', tuning_search_space[ti])
      logging.info('Metrics: %s', all_metrics[ti])
      logging.info('Timing: %s', all_timings[ti])
      logging.info('=' * 20)
  else:
    rng_seed = struct.unpack('q', os.urandom(8))[0]
    rng = prng.PRNGKey(rng_seed)
    # If the submission is responsible for tuning itself, we only need to run
    # it once and return the total time.
    score, _ = train_once(workload, global_batch_size, data_dir,
                          init_optimizer_state, update_params, data_selection,
                          None, rng)
  # TODO(znado): record and return other information (number of steps).
  return score


def main(_):
  # Check if distributed data parallel is used.
  use_pytorch_ddp = 'LOCAL_RANK' in os.environ
  if FLAGS.framework == 'pytorch':
    # From the docs: "(...) causes cuDNN to benchmark multiple convolution
    # algorithms and select the fastest."
    torch.backends.cudnn.benchmark = True

    if use_pytorch_ddp:
      # LOCAL_RANK is set by the distributed launcher (e.g. torchrun).
      rank = int(os.environ['LOCAL_RANK'])
      torch.cuda.set_device(rank)
      # Only log once (for local rank == 0).
      if rank != 0:

        def logging_pass(*args):
          pass

        logging.info = logging_pass
      # Initialize the process group.
      dist.init_process_group('nccl')

  workload_metadata = WORKLOADS[FLAGS.workload]
  # Extend the workload path according to the framework.
  workload_metadata['workload_path'] = os.path.join(
      BASE_WORKLOADS_DIR,
      workload_metadata['workload_path'] + '_' + FLAGS.framework,
      'workload.py')
  workload = import_workload(
      workload_path=workload_metadata['workload_path'],
      workload_class_name=workload_metadata['workload_class_name'])

  score = score_submission_on_workload(workload,
                                       FLAGS.workload,
                                       FLAGS.submission_path,
                                       FLAGS.data_dir,
                                       FLAGS.tuning_ruleset,
                                       FLAGS.tuning_search_space,
                                       FLAGS.num_tuning_trials)
  logging.info('Final %s score: %f', FLAGS.workload, score)

  if use_pytorch_ddp:
    # Clean up the process group.
    dist.destroy_process_group()


if __name__ == '__main__':
  app.run(main)
For this workload, additional manual annotations were used to evaluate Hotline's automatic annotation, which was measured to be 96.95% accurate.
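For reference, manual annotations use the same hotline.annotate context manager that appears in the training loop above, with hotline.annotate.step() marking iteration boundaries. A minimal hypothetical sketch of manually annotating one training step (the stage names follow the timeline above; how Hotline scores automatic against manual annotations is internal to the tool):

import hotline

def train_step(model, batch, loss_fn, optimizer):
  # Hypothetical manual phase annotations for one iteration.
  with hotline.annotate('Forward'):
    output = model(batch['inputs'])
  with hotline.annotate('Calc Loss'):
    loss = loss_fn(output, batch['targets'])
  with hotline.annotate('Zero Grad'):
    optimizer.zero_grad()
  with hotline.annotate('Backward'):
    loss.backward()
  with hotline.annotate('Optimizer'):
    optimizer.step()
  hotline.annotate.step()  # Mark the end of one training iteration.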
Metadata:
name: RNN-T Training Iteration
type: root
resources.cpu11.time.ts: 1685410791420744000
resources.cpu11.time.dur: 852391000
resources.cpu11.time.parent_is_longest: false
resources.cpu11.time.runtime_str: 852 ms
resources.cpu2.time.ts: 1685410790604527000
resources.cpu2.time.dur: 1671969000
resources.cpu2.time.parent_is_longest: true
resources.cpu2.time.runtime_str: 1.67 s
resources.cpu6.time.ts: 1685410790770475000
resources.cpu6.time.dur: 120671000
resources.cpu6.time.parent_is_longest: false
resources.cpu6.time.runtime_str: 121 ms
resources.gpu10.time.ts: 1685410791208622000
resources.gpu10.time.dur: 206616000
resources.gpu10.time.parent_is_longest: false
resources.gpu10.time.runtime_str: 207 ms
resources.gpu12.time.ts: 1685410791544330000
resources.gpu12.time.dur: 842639000
resources.gpu12.time.parent_is_longest: false
resources.gpu12.time.runtime_str: 843 ms
resources.gpu13.time.ts: 1685410791544332000
resources.gpu13.time.dur: 825889000
resources.gpu13.time.parent_is_longest: false
resources.gpu13.time.runtime_str: 826 ms
resources.gpu14.time.ts: 1685410791593676000
resources.gpu14.time.dur: 807460000
resources.gpu14.time.parent_is_longest: false
resources.gpu14.time.runtime_str: 807 ms
resources.gpu15.time.ts: 1685410791593804000
resources.gpu15.time.dur: 808133000
resources.gpu15.time.parent_is_longest: false
resources.gpu15.time.runtime_str: 808 ms
resources.gpu3.time.ts: 1685410791011158000
resources.gpu3.time.dur: 1538011000
resources.gpu3.time.parent_is_longest: true
resources.gpu3.time.runtime_str: 1.54 s
resources.gpu4.time.ts: 1685410791579645000
resources.gpu4.time.dur: 854335000
resources.gpu4.time.parent_is_longest: false
resources.gpu4.time.runtime_str: 854 ms
resources.gpu5.time.ts: 1685410791584691000
resources.gpu5.time.dur: 849361000
resources.gpu5.time.parent_is_longest: false
resources.gpu5.time.runtime_str: 849 ms
resources.gpu7.time.ts: 1685410791115832000
resources.gpu7.time.dur: 265604000
resources.gpu7.time.parent_is_longest: false
resources.gpu7.time.runtime_str: 266 ms
resources.gpu8.time.ts: 1685410791115832000
resources.gpu8.time.dur: 270873000
resources.gpu8.time.parent_is_longest: false
resources.gpu8.time.runtime_str: 271 ms
resources.gpu9.time.ts: 1685410791208581000
resources.gpu9.time.dur: 206523000
resources.gpu9.time.parent_is_longest: false
resources.gpu9.time.runtime_str: 207 ms
source_file_name: /home/ubuntu/algorithmic-efficiency/submission_runner.py
source_file_num: 236
id: rjdkRLSrsmjNo9BT
pretty_name: RNN-T Training Iteration
total_accuracy_str: 96.95%
trace_file: /results/RNN/RNN.1.pt.trace.json
trace_disk_size: 81.1 MB
ui_source_code_path: /home/ubuntu/cpath/results/ui/dist/traces/results/code/RNN/home/ubuntu/algorithmic-efficiency/submission_runner.py
config.trace_filepath: /home/ubuntu/cpath/results/RNN.worker0.pt.trace.json
config.output_dir: /home/ubuntu/cpath/results
config.ui_traces_path: /home/ubuntu/cpath/results/ui/dist/traces/results/RNN
config.ui_model_path: /home/ubuntu/cpath/results/ui/src/results/RNN.js
config.results_summary_csv_filepath: /home/ubuntu/cpath/results/results_summary.csv
config.ui_source_code_path: /home/ubuntu/cpath/results/ui/dist/traces/results/code/RNN
config.run_name: RNN
config.model_name: RNN-T
config.backend: torch
config.test: false
config.is_test_accuracy: true
config.view_manual_annotations: false
config.source_file_name: /home/ubuntu/algorithmic-efficiency/submission_runner.py
config.source_file_num: 236
config.num_gpus: 1
config.slice_idx: 0
config.drop_flow_view:
config.last_found_was_fused: false
config.count_sum_greater_than_1: 0
config.count_sum_less_than_1: 0
config.tiny_op_threshold: 0.05
config.max_generated_depth: 1
config.remove_slice_args: true
config.write_model_ops_to_file: true
config.op_idx: 1
metadata.model: RNN-T
metadata.dataset: LibriSpeech
metadata.batch_size: 64
metadata.optimizer: Adam
trace_event_count: 216681
pytorch_version: 1.13.1+cu117
gpu_model: NVIDIA GeForce RTX 3090
gpu_cuda_version: 11.7
hotline_traces_trace_disk_size: 193.6 MB
hotline_annotation_count: 185
processed_datetime: 30/05/2023 01:40:41
runtime_without_profiling: 1.69 s ±0.4%
runtime_with_profiling: 1.77 s ±0.3%
runtime_profiling_overhead_factor: 0.05× slower
hotline_analysis_time: 21.3 s
runtime: 1671969000
runtime_str: 1.67 s
start_timestamp: 01:39:50.604527
recommendations: None
bound_by: CPU-Bound
longest_res: cpu2
resourceName: undefined
slice_count: 2
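As a quick consistency check, the profiling overhead factor follows from the two runtime measurements above (a sketch, assuming the factor is defined as the relative slowdown):

without_profiling = 1.69  # seconds, runtime_without_profiling
with_profiling = 1.77  # seconds, runtime_with_profiling

# Relative slowdown: (1.77 / 1.69) - 1 ≈ 0.047, reported as "0.05× slower".
overhead_factor = with_profiling / without_profiling - 1
print(f'{overhead_factor:.2f}x slower')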