1r"""Run a submission on a single workload.
2
3Example command:
4
5python3 submission_runner.py \
6 --workload=mnist \
7 --framework=jax \
8 --submission_path=reference_submissions/mnist/mnist_jax/submission.py \
9 --tuning_ruleset=external \
10 --tuning_search_space=reference_submissions/mnist/tuning_search_space.json \
11 --num_tuning_trials=3
12"""
import datetime
import importlib
import inspect
import json
import os
import struct
import sys
import time
from typing import Optional, Tuple

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import torch
import torch.distributed as dist

from algorithmic_efficiency import halton
from algorithmic_efficiency import random_utils as prng
from algorithmic_efficiency import spec

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX/PyTorch.
tf.config.experimental.set_visible_devices([], 'GPU')

BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'

WORKLOADS = {
    'cifar': {
        'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload'
    },
    'criteo1tb': {
        'workload_path': 'criteo1tb/criteo1tb',
        'workload_class_name': 'Criteo1TbDlrmSmallWorkload'
    },
    'fastmri': {
        'workload_path': 'fastmri/fastmri',
        'workload_class_name': 'FastMRIWorkload'
    },
    'imagenet_resnet': {
        'workload_path': 'imagenet_resnet/imagenet',
        'workload_class_name': 'ImagenetResNetWorkload'
    },
    'imagenet_vit': {
        'workload_path': 'imagenet_vit/imagenet',
        'workload_class_name': 'ImagenetVitWorkload'
    },
    'librispeech': {
        'workload_path': 'librispeech/librispeech',
        'workload_class_name': 'LibriSpeechWorkload'
    },
    'mnist': {
        'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload'
    },
    'ogbg': {
        'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'
    },
    'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'},
}
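
# Note: each 'workload_path' above is a relative stub; main() appends the
# framework suffix and '/workload.py', e.g. 'mnist/mnist' becomes
# 'algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py' when
# --framework=jax.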

flags.DEFINE_string(
    'submission_path',
    'reference_submissions/mnist/mnist_jax/submission.py',
    'The relative path of the Python file containing submission functions. '
    'NOTE: the submission dir must have an __init__.py file!')
flags.DEFINE_string(
    'workload',
    'mnist',
    help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}'
)
flags.DEFINE_enum(
    'tuning_ruleset',
    'external',
    enum_values=['external', 'self'],
    help='Which tuning ruleset to use.')
flags.DEFINE_string(
    'tuning_search_space',
    'reference_submissions/mnist/tuning_search_space.json',
    'The path to the JSON file describing the external tuning search space.')
flags.DEFINE_integer('num_tuning_trials',
                     20,
                     'The number of external hyperparameter trials to run.')
flags.DEFINE_string('data_dir', '~/tensorflow_datasets/', 'Dataset location')
flags.DEFINE_enum(
    'framework',
    None,
    enum_values=['jax', 'pytorch'],
    help='Whether to use Jax or Pytorch for the submission. Controls among '
    'other things if the Jax or Numpy RNG library is used for RNG.')

FLAGS = flags.FLAGS


def convert_filepath_to_module(path: str) -> str:
  """Convert the path of a Python file into a dotted module name."""
  base, extension = os.path.splitext(path)
  if extension != '.py':
    raise ValueError(f'Path: {path} must be a python file (*.py)')
  return base.replace('/', '.')
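
# For example, 'reference_submissions/mnist/mnist_jax/submission.py' is
# converted to 'reference_submissions.mnist.mnist_jax.submission', which can
# then be passed to importlib.import_module().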


def import_workload(workload_path: str,
                    workload_class_name: str,
                    return_class=False) -> spec.Workload:
  """Import and add the workload to the registry.

  This importlib loading is nice to have because it allows runners to avoid
  installing the dependencies of all the supported frameworks. For example, if
  a submitter only wants to write Jax code, only the Jax workload module is
  imported, so the PyTorch dependencies never need to be installed on their
  system.

  Args:
    workload_path: the path to the `workload.py` file to load.
    workload_class_name: the name of the Workload class that implements the
      `Workload` abstract class in `spec.py`.
    return_class: if true, then the workload class is returned instead of the
      instantiated object. Useful for testing when methods need to be
      overridden.
  """
  workload_path = convert_filepath_to_module(workload_path)
  workload_module = importlib.import_module(workload_path)
  workload_module_members = inspect.getmembers(workload_module)
  workload_class = None
  for name, value in workload_module_members:
    if name == workload_class_name:
      workload_class = value
      break
  if workload_class is None:
    raise ValueError(
        f'Could not find member {workload_class_name} in {workload_path}. '
        'Make sure the Workload class is spelled correctly and defined in '
        'the top scope of the module.')
  if return_class:
    return workload_class
  return workload_class()
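
# Example (mirroring main() below): load the Jax MNIST workload.
#   workload = import_workload(
#       workload_path='algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py',
#       workload_class_name='MnistWorkload')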


def train_once(workload: spec.Workload,
               global_batch_size: int,
               data_dir: str,
               init_optimizer_state: spec.InitOptimizerFn,
               update_params: spec.UpdateParamsFn,
               data_selection: spec.DataSelectionFn,
               hyperparameters: Optional[spec.Hyperparameters],
               rng: spec.RandomState) -> Tuple[spec.Timing, spec.Steps]:
  data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

  # Workload setup.
  logging.info('Initializing dataset.')
  input_queue = workload.build_input_queue(
      data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size)
  logging.info('Initializing model.')
  model_params, model_state = workload.init_model_fn(model_init_rng)
  logging.info('Initializing optimizer.')
  optimizer_state = init_optimizer_state(workload,
                                         model_params,
                                         model_state,
                                         hyperparameters,
                                         opt_init_rng)

  # Bookkeeping.
  goal_reached = False
  is_time_remaining = True
  last_eval_time = 0
  accumulated_submission_time = 0
  eval_results = []
  global_step = 0
  training_complete = False
  global_start_time = time.time()

  # Hotline profiling setup. Imported here so the rest of the runner does not
  # depend on hotline being installed.
  import hotline

  last_time = datetime.datetime.now()
  logging.info('Profiling start time: %s', last_time)
  # Replicate the model across all visible GPUs (single-process data
  # parallelism).
  model_params = torch.nn.DataParallel(model_params)

  # Profiler schedule: a short capture when HOTLINE_QUICK_RUN is set,
  # otherwise a longer one.
  quick_run = os.environ.get('HOTLINE_QUICK_RUN')
  if quick_run:
    wait = 2
    warmup = 2
    active = 1
  else:
    wait = 20
    warmup = 19
    active = 1

  # Stop training once the profiler has finished its active steps.
  max_steps = wait + warmup + active

  # Metadata recorded alongside the trace (hard-coded for the profiled
  # workload).
  metadata = {
      'model': 'RNN-T',
      'dataset': 'LibriSpeech',
      'batch_size': global_batch_size,
      'optimizer': 'Adam',
      'runtime': [],
  }

  torch_profiler = torch.profiler.profile(
      activities=[
          torch.profiler.ProfilerActivity.CPU,
          torch.profiler.ProfilerActivity.CUDA
      ],
      schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
      on_trace_ready=hotline.analyze(
          model_params,
          input_queue,
          run_name='RNN',
          test_accuracy=True,
          output_dir='/home/dans/cpath',
          metadata=metadata,
      ),
      record_shapes=False,
      profile_memory=False,
      with_stack=False)
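  # The schedule skips `wait` steps, spends `warmup` steps warming the
  # profiler up (events are collected but discarded), then records `active`
  # steps; on_trace_ready fires once the active window completes.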

  logging.info('Starting training loop.')
  while is_time_remaining and not goal_reached and not training_complete:
    step_rng = prng.fold_in(rng, global_step)
    data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
    start_time = time.time()
    with hotline.annotate('Load Data'):
      batch = data_selection(workload,
                             input_queue,
                             optimizer_state,
                             model_params,
                             hyperparameters,
                             global_step,
                             data_select_rng)
    try:
      optimizer_state, model_params, model_state = update_params(
          workload=workload,
          current_param_container=model_params,
          current_params_types=workload.model_params_types,
          model_state=model_state,
          hyperparameters=hyperparameters,
          batch=batch,
          loss_type=workload.loss_type,
          optimizer_state=optimizer_state,
          eval_results=eval_results,
          global_step=global_step,
          rng=update_rng)

      # Record the wall-clock time of this step for the hotline metadata.
      this_time = datetime.datetime.now()
      tdelta = this_time - last_time
      logging.info(f'tdelta: {tdelta}')
      metadata['runtime'].append(tdelta)
      last_time = this_time
      logging.info(f'global_step: {global_step}\n')

      # Advance the profiler and hotline annotation state; exit once the
      # profiling window has been captured.
      torch_profiler.step()
      hotline.annotate.step()
      if global_step >= max_steps:
        sys.exit(0)
      global_step += 1
      # Skip the submission-time accounting below; this run is for profiling
      # only.
      continue
    except spec.TrainingCompleteError:
      training_complete = True
    global_step += 1
    current_time = time.time()
    accumulated_submission_time += current_time - start_time
    is_time_remaining = (
        accumulated_submission_time < workload.max_allowed_runtime_sec)

  metrics = {'eval_results': eval_results, 'global_step': global_step}
  return accumulated_submission_time, metrics


def score_submission_on_workload(workload: spec.Workload,
                                 workload_name: str,
                                 submission_path: str,
                                 data_dir: str,
                                 tuning_ruleset: str,
                                 tuning_search_space: Optional[str] = None,
                                 num_tuning_trials: Optional[int] = None):
  # Load the submission's functions from the given Python file.
  submission_module_path = convert_filepath_to_module(submission_path)
  submission_module = importlib.import_module(submission_module_path)

  init_optimizer_state = submission_module.init_optimizer_state
  update_params = submission_module.update_params
  data_selection = submission_module.data_selection
  global_batch_size = submission_module.get_batch_size(workload_name)

  if tuning_ruleset == 'external':
    # If the submission runner is responsible for hyperparameter tuning, load
    # the search space and generate a list of trials to run.
    if tuning_search_space is None:
      raise ValueError(
          'Must provide a tuning search space JSON file when using external '
          'tuning.')
    with open(tuning_search_space, 'r', encoding='UTF-8') as search_space_file:
      tuning_search_space = halton.generate_search(
          json.load(search_space_file), num_tuning_trials)
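    # halton.generate_search turns the JSON-defined search space into a list
    # of num_tuning_trials hyperparameter settings (quasi-random sampling,
    # judging by the module name).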
    all_timings = []
    all_metrics = []
    for hi, hyperparameters in enumerate(tuning_search_space):
      # Generate a new seed from hardware sources of randomness for each trial.
      rng_seed = struct.unpack('I', os.urandom(4))[0]
      rng = prng.PRNGKey(rng_seed)
      rng, _ = prng.split(rng, 2)
      logging.info('--- Tuning run %d/%d ---', hi + 1, num_tuning_trials)
      timing, metrics = train_once(workload, global_batch_size, data_dir,
                                   init_optimizer_state, update_params,
                                   data_selection, hyperparameters, rng)
      all_timings.append(timing)
      all_metrics.append(metrics)
    score = min(all_timings)
    for ti in range(num_tuning_trials):
      logging.info('Tuning trial %d/%d', ti + 1, num_tuning_trials)
      logging.info('Hyperparameters: %s', tuning_search_space[ti])
      logging.info('Metrics: %s', all_metrics[ti])
      logging.info('Timing: %s', all_timings[ti])
      logging.info('=' * 20)
  else:
    rng_seed = struct.unpack('q', os.urandom(8))[0]
    rng = prng.PRNGKey(rng_seed)
    # Self-tuning submissions are run once, without external hyperparameters.
    score, _ = train_once(workload, global_batch_size, data_dir,
                          init_optimizer_state, update_params, data_selection,
                          None, rng)
  return score


def main(_):
  # DDP is active when the script is launched via torch.distributed (e.g.
  # torchrun), which sets LOCAL_RANK for every process.
  use_pytorch_ddp = 'LOCAL_RANK' in os.environ
  if FLAGS.framework == 'pytorch':
    # Let cuDNN benchmark and select the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True

    if use_pytorch_ddp:
      rank = int(os.environ['LOCAL_RANK'])
      torch.cuda.set_device(rank)
      # Only log from the first process to avoid duplicated output.
      if rank != 0:

        def logging_pass(*args):
          pass

        logging.info = logging_pass

      dist.init_process_group('nccl')

  workload_metadata = WORKLOADS[FLAGS.workload]
  # Extend the relative workload path to the framework-specific workload.py
  # file.
  workload_metadata['workload_path'] = os.path.join(
      BASE_WORKLOADS_DIR,
      workload_metadata['workload_path'] + '_' + FLAGS.framework,
      'workload.py')
  workload = import_workload(
      workload_path=workload_metadata['workload_path'],
      workload_class_name=workload_metadata['workload_class_name'])

  score = score_submission_on_workload(workload,
                                       FLAGS.workload,
                                       FLAGS.submission_path,
                                       FLAGS.data_dir,
                                       FLAGS.tuning_ruleset,
                                       FLAGS.tuning_search_space,
                                       FLAGS.num_tuning_trials)
  logging.info('Final %s score: %f', FLAGS.workload, score)

  if use_pytorch_ddp:
    # Clean up the distributed process group.
    dist.destroy_process_group()


if __name__ == '__main__':
  app.run(main)