python cuda
  • 板块灌水区
  • 楼主xie_yuting
  • 当前回复1
  • 已保存回复1
  • 发布时间2024/11/9 08:21
  • 上次更新2024/11/9 11:38:59
查看原帖
python cuda
1025332
xie_yuting楼主2024/11/9 08:21

萌新刚学cuda,有会cuda的大佬能指点一下吗?

import numpy as np
from numba import cuda
@cuda.jit
def is_prime_gpu(n, results):
    idx = cuda.grid(1)  # 获取全局线程索引
    if idx < n.size:
        number = n[idx]
        if number <= 1:
            results[idx] = 0
            return
        is_prime = 1
        for j in range(2, int(number ** 0.5) + 1):
            if number % j == 0:
                is_prime = 0
                break
        results[idx] = is_prime


def check_primes_on_gpu(numbers):
    n = np.array(numbers)
    results = np.zeros(n.shape, dtype=np.int32)

    # 将数据传输到设备
    d_numbers = cuda.to_device(n)
    d_results = cuda.to_device(results)

    # 定义块和网格的大小
    threads_per_block = 32
    blocks_per_grid = (n.size + (threads_per_block - 1)) // threads_per_block

    # 启动内核
    is_prime_gpu[blocks_per_grid, threads_per_block](d_numbers, d_results)

    # 将结果从设备传回主机
    d_results.copy_to_host(results)

    # 0表示不是质数,1表示是质数
    return results

errors:

D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py:536: NumbaPerformanceWarning: Grid size 1 will likely result in GPU under-utilization due to low occupancy.
  warn(NumbaPerformanceWarning(msg))
Traceback (most recent call last):
  File "C:\Users\Xyt13\Documents\primes.py", line 58, in <module>
    if check_primes_on_gpu([s]):
       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Xyt13\Documents\primes.py", line 44, in check_primes_on_gpu
    is_prime_gpu[blocks_per_grid, threads_per_block](d_numbers, d_results)
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 539, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 681, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 689, in _compile_for_args
    return self.compile(tuple(argtypes))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 932, in compile
    kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\dispatcher.py", line 83, in __init__
    cres = compile_cuda(self.py_func, types.void, self.argtypes,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\cuda\compiler.py", line 196, in compile_cuda
    cres = compiler.compile_extra(typingctx=typingctx,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 744, in compile_extra
    return pipeline.compile_extra(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 438, in compile_extra
    return self._compile_bytecode()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 506, in _compile_bytecode
    return self._compile_core()
           ^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 485, in _compile_core
    raise e
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler.py", line 472, in _compile_core
    pm.run(self.state)
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 368, in run
    raise patched_exception
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
              ^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\typed_passes.py", line 112, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
                                            ^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\typed_passes.py", line 93, in type_inference_stage
    errs = infer.propagate(raise_errors=raise_errors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\Users\Xyt13\anaconda\Lib\site-packages\numba\core\typeinfer.py", line 1091, in propagate
    raise errors[0]
numba.core.errors.TypingError: Failed in cuda mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
During: typing of argument at C:\Users\Xyt13\Documents\primes.py (15)

File "primes.py", line 15:
def isPrime(n):
    <source elided>

@cuda.jit
^
2024/11/9 08:21
加载中...