diff --git a/unified-runtime/source/adapters/offload/device.cpp b/unified-runtime/source/adapters/offload/device.cpp index b5d4fdc571dd1..6dd086f9c3034 100644 --- a/unified-runtime/source/adapters/offload/device.cpp +++ b/unified-runtime/source/adapters/offload/device.cpp @@ -77,10 +77,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS: return ReturnValue(uint32_t{3}); case UR_DEVICE_INFO_COMPILER_AVAILABLE: + case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT: return ReturnValue(true); // Unimplemented features case UR_DEVICE_INFO_PROGRAM_SET_SPECIALIZATION_CONSTANTS: - case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT: case UR_DEVICE_INFO_USM_POOL_SUPPORT: case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: case UR_DEVICE_INFO_IMAGE_SUPPORT: diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp index 9b5cd9140a0f1..62b91f82afed7 100644 --- a/unified-runtime/source/adapters/offload/enqueue.cpp +++ b/unified-runtime/source/adapters/offload/enqueue.cpp @@ -93,11 +93,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } -UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( - ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead, - size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList, - const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - +namespace { +ur_result_t doMemcpy(ur_queue_handle_t hQueue, void *DestPtr, + ol_device_handle_t DestDevice, const void *SrcPtr, + ol_device_handle_t SrcDevice, size_t size, bool blocking, + uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { // Ignore wait list for now (void)numEventsInWaitList; (void)phEventWaitList; @@ -105,14 +107,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( ol_event_handle_t EventOut = nullptr; - char *DevPtr = - reinterpret_cast(std::get(hBuffer->Mem).Ptr); - - OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice, - DevPtr + offset, hQueue->OffloadDevice, size, - phEvent ? &EventOut : nullptr)); + OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DestPtr, DestDevice, SrcPtr, + SrcDevice, size, phEvent ? &EventOut : nullptr)); - if (blockingRead) { + if (blocking) { OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue)); } @@ -124,37 +122,63 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( return UR_RESULT_SUCCESS; } +} // namespace + +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead( + ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead, + size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { + char *DevPtr = + reinterpret_cast(std::get(hBuffer->Mem).Ptr); + + return doMemcpy(hQueue, pDst, Adapter->HostDevice, DevPtr + offset, + hQueue->OffloadDevice, size, blockingRead, + numEventsInWaitList, phEventWaitList, phEvent); +} UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite( ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite, size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { - - // Ignore wait list for now - (void)numEventsInWaitList; - (void)phEventWaitList; - // - - ol_event_handle_t EventOut = nullptr; - char *DevPtr = reinterpret_cast(std::get(hBuffer->Mem).Ptr); - OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset, - hQueue->OffloadDevice, pSrc, Adapter->HostDevice, - size, phEvent ? &EventOut : nullptr)); + return doMemcpy(hQueue, DevPtr + offset, hQueue->OffloadDevice, pSrc, + Adapter->HostDevice, size, blockingWrite, numEventsInWaitList, + phEventWaitList, phEvent); +} - if (blockingWrite) { - OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue)); +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( + ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name, + bool blockingRead, size_t count, size_t offset, void *pDst, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + void *Ptr; + if (auto Err = urProgramGetGlobalVariablePointer(nullptr, hProgram, name, + nullptr, &Ptr)) { + return Err; } - if (phEvent) { - auto *Event = new ur_event_handle_t_(); - Event->OffloadEvent = EventOut; - *phEvent = Event; + return doMemcpy(hQueue, pDst, Adapter->HostDevice, + reinterpret_cast(Ptr) + offset, + hQueue->OffloadDevice, count, blockingRead, + numEventsInWaitList, phEventWaitList, phEvent); +} + +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite( + ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name, + bool blockingWrite, size_t count, size_t offset, const void *pSrc, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + void *Ptr; + if (auto Err = urProgramGetGlobalVariablePointer(nullptr, hProgram, name, + nullptr, &Ptr)) { + return Err; } - return UR_RESULT_SUCCESS; + return doMemcpy(hQueue, reinterpret_cast(Ptr) + offset, + hQueue->OffloadDevice, pSrc, Adapter->HostDevice, count, + blockingWrite, numEventsInWaitList, phEventWaitList, phEvent); } ur_result_t enqueueNoOp(ur_queue_handle_t hQueue, ur_event_handle_t *phEvent) { diff --git a/unified-runtime/source/adapters/offload/program.cpp b/unified-runtime/source/adapters/offload/program.cpp index e889f59ef8402..b2297b2b4bb1d 100644 --- a/unified-runtime/source/adapters/offload/program.cpp +++ b/unified-runtime/source/adapters/offload/program.cpp @@ -29,7 +29,7 @@ namespace { #ifdef UR_CUDA_ENABLED ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext, const uint8_t *Binary, size_t Length, - ur_program_handle_t *phProgram) { + ur_program_handle_t hProgram) { uint8_t *RealBinary; size_t RealLength; CUlinkState State; @@ -48,25 +48,17 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext, fprintf(stderr, "Performed CUDA bin workaround (size = %lu)\n", RealLength); #endif - ur_program_handle_t Program = new ur_program_handle_t_(); auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary, - RealLength, &Program->OffloadProgram); + RealLength, &hProgram->OffloadProgram); // Program owns the linked module now cuLinkDestroy(State); - if (Res != OL_SUCCESS) { - delete Program; - return offloadResultToUR(Res); - } - - *phProgram = Program; - - return UR_RESULT_SUCCESS; + return offloadResultToUR(Res); } #else ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *, - size_t, ur_program_handle_t *) { + size_t, ur_program_handle_t) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } #endif @@ -76,7 +68,8 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *, UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary( ur_context_handle_t hContext, uint32_t numDevices, ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries, - const ur_program_properties_t *, ur_program_handle_t *phProgram) { + const ur_program_properties_t *pProperties, + ur_program_handle_t *phProgram) { if (numDevices > 1) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } @@ -100,24 +93,55 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary( } } + ur_program_handle_t Program = new ur_program_handle_t_{}; + Program->URContext = hContext; + Program->Binary = RealBinary; + Program->BinarySizeInBytes = RealLength; + + // Parse properties + if (pProperties) { + if (pProperties->count > 0 && pProperties->pMetadatas == nullptr) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } else if (pProperties->count == 0 && pProperties->pMetadatas != nullptr) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + + auto Length = pProperties->count; + auto Metadata = pProperties->pMetadatas; + for (size_t i = 0; i < Length; ++i) { + const ur_program_metadata_t MetadataElement = Metadata[i]; + std::string MetadataElementName{MetadataElement.pName}; + + auto [Prefix, Tag] = splitMetadataName(MetadataElementName); + + if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) { + const char *MetadataValPtr = + reinterpret_cast(MetadataElement.value.pData) + + sizeof(std::uint64_t); + const char *MetadataValPtrEnd = + MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t); + Program->GlobalIDMD[Prefix] = + std::string{MetadataValPtr, MetadataValPtrEnd}; + } + } + } + + ur_result_t Res; ol_platform_backend_t Backend; olGetPlatformInfo(phDevices[0]->Platform->OffloadPlatform, OL_PLATFORM_INFO_BACKEND, sizeof(Backend), &Backend); if (Backend == OL_PLATFORM_BACKEND_CUDA) { - return ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength, - phProgram); + Res = + ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength, Program); + } else { + Res = offloadResultToUR(olCreateProgram(hContext->Device->OffloadDevice, + RealBinary, RealLength, + &Program->OffloadProgram)); } - ur_program_handle_t Program = new ur_program_handle_t_{}; - Program->URContext = hContext; - Program->Binary = RealBinary; - Program->BinarySizeInBytes = RealLength; - auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary, - RealLength, &Program->OffloadProgram); - - if (Res != OL_SUCCESS) { + if (Res != UR_RESULT_SUCCESS) { delete Program; - return offloadResultToUR(Res); + return Res; } *phProgram = Program; @@ -240,3 +264,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants( ur_program_handle_t, uint32_t, const ur_specialization_constant_info_t *) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } + +UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer( + ur_device_handle_t, ur_program_handle_t hProgram, + const char *pGlobalVariableName, size_t *pGlobalVariableSizeRet, + void **ppGlobalVariablePointerRet) { + auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(pGlobalVariableName); + if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end()) + return UR_RESULT_ERROR_INVALID_VALUE; + std::string DeviceGlobalName = DeviceGlobalNameIt->second; + + ol_symbol_handle_t Symbol; + auto Err = olGetSymbol(hProgram->OffloadProgram, DeviceGlobalName.c_str(), + OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Symbol); + if (Err && Err->Code == OL_ERRC_NOT_FOUND) { + return UR_RESULT_ERROR_INVALID_VALUE; + } + OL_RETURN_ON_ERR(Err); + + if (pGlobalVariableSizeRet) { + OL_RETURN_ON_ERR(olGetSymbolInfo(Symbol, + OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, + sizeof(size_t), pGlobalVariableSizeRet)); + } + OL_RETURN_ON_ERR(olGetSymbolInfo(Symbol, + OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, + sizeof(void *), ppGlobalVariablePointerRet)); + + return UR_RESULT_SUCCESS; +} diff --git a/unified-runtime/source/adapters/offload/program.hpp b/unified-runtime/source/adapters/offload/program.hpp index abd0f41c18ba9..0f93fe68d93d0 100644 --- a/unified-runtime/source/adapters/offload/program.hpp +++ b/unified-runtime/source/adapters/offload/program.hpp @@ -20,4 +20,6 @@ struct ur_program_handle_t_ : RefCounted { ur_context_handle_t URContext; const uint8_t *Binary; size_t BinarySizeInBytes; + // A mapping from mangled global names -> names in the binary + std::unordered_map GlobalIDMD; }; diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp index 3c6a9a86d4def..02de9df99fddc 100644 --- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp @@ -92,7 +92,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable( pDdiTable->pfnCreateWithNativeHandle = urProgramCreateWithNativeHandle; pDdiTable->pfnGetBuildInfo = nullptr; pDdiTable->pfnGetFunctionPointer = nullptr; - pDdiTable->pfnGetGlobalVariablePointer = nullptr; + pDdiTable->pfnGetGlobalVariablePointer = urProgramGetGlobalVariablePointer; pDdiTable->pfnGetInfo = urProgramGetInfo; pDdiTable->pfnGetNativeHandle = urProgramGetNativeHandle; pDdiTable->pfnLink = nullptr; @@ -168,8 +168,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( if (UR_RESULT_SUCCESS != result) { return result; } - pDdiTable->pfnDeviceGlobalVariableRead = nullptr; - pDdiTable->pfnDeviceGlobalVariableWrite = nullptr; + pDdiTable->pfnDeviceGlobalVariableRead = urEnqueueDeviceGlobalVariableRead; + pDdiTable->pfnDeviceGlobalVariableWrite = urEnqueueDeviceGlobalVariableWrite; pDdiTable->pfnEventsWait = nullptr; pDdiTable->pfnEventsWaitWithBarrier = nullptr; pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;