Getting ptx from Triton
Published:
You can get the ptx of a triton kernel like so: my_kernel.cache[DEVICE_KEY][INPUTS_KEY].asm['ptx']
, where DEVICE_KEY
and INPUTS_KEY
are determined like below.
Case 1 (most likely): If you’ve compiled the kernel only for a single device and a single set of inputs, cache
only has a single key-value mapping, which itself only as a single key-value mapping. So this works:
def value(dict_):
assert len(dict_)==1, 'dict has more than one value' # we're assuming a single env & a single input set
return list(dict_.values())[0] # return that single value
value(value(my_kernel.cache)).asm['ptx']
Case 2: If you’ve compiled for multiple devices or input sets, you need to find DEVICE_KEY
and INPUTS_KEY
. Print my_kernel.cache.keys()
and select the device key you need, the print the keys()
for that again to get the inputs key you want.
That’s it!
- Umer