萌新刚学cuda,有会cuda的大佬能指点一下吗?
import numpy as np
from numba import cuda
@cuda.jit
def is_prime_gpu(n, results):
idx = cuda.grid(1) # 获取全局线程索引
if idx < n.size:
number = n[idx]
if number <= 1:
results[idx] = 0
return
is_prime = 1
for j in range(2, int(number ** 0.5) + 1):
if number % j == 0:
is_prime = 0
break
results[idx] = is_prime
def check_primes_on_gpu(numbers):
n = np.array(numbers)
results = np.zeros(n.shape, dtype=np.int32)
# 将数据传输到设备
d_numbers = cuda.to_device(n)
d_results = cuda.to_device(results)
# 定义块和网格的大小
threads_per_block = 32
blocks_per_grid = (n.size + (threads_per_block - 1)) // threads_per_block
# 启动内核
is_prime_gpu[blocks_per_grid, threads_per_block](d_numbers, d_results)
# 将结果从设备传回主机
d_results.copy_to_host(results)
# 0表示不是质数,1表示是质数
return results
errors:
D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py:536: NumbaPerformanceWarning: Grid size 1 will likely result in GPU under-utilization due to low occupancy.
warn(NumbaPerformanceWarning(msg))
Traceback (most recent call last):
File "C:\Users\Xyt13\Documents\primes.py", line 58, in <module>
if check_primes_on_gpu([s]):
^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\Xyt13\Documents\primes.py", line 44, in check_primes_on_gpu
is_prime_gpu[blocks_per_grid, threads_per_block](d_numbers, d_results)
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 539, in __call__
return self.dispatcher.call(args, self.griddim, self.blockdim,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 681, in call
kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 689, in _compile_for_args
return self.compile(tuple(argtypes))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 932, in compile
kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 83, in __init__
cres = compile_cuda(self.py_func, types.void, self.argtypes,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\compiler.py", line 196, in compile_cuda
cres = compiler.compile_extra(typingctx=typingctx,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 744, in compile_extra
return pipeline.compile_extra(func)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 438, in compile_extra
return self._compile_bytecode()
^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 506, in _compile_bytecode
return self._compile_core()
^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 485, in _compile_core
raise e
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 472, in _compile_core
pm.run(self.state)
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 368, in run
raise patched_exception
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 356, in run
self._runPass(idx, pass_inst, state)
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 311, in _runPass
mutated |= check(pss.run_pass, internal_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 273, in check
mangled = func(compiler_state)
^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\typed_passes.py", line 112, in run_pass
typemap, return_type, calltypes, errs = type_inference_stage(
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\typed_passes.py", line 93, in type_inference_stage
errs = infer.propagate(raise_errors=raise_errors)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\typeinfer.py", line 1091, in propagate
raise errors[0]
numba.core.errors.TypingError: Failed in cuda mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
During: typing of argument at C:\Users\Xyt13\Documents\primes.py (15)
File "primes.py", line 15:
def isPrime(n):
<source elided>
@cuda.jit
^