Skip to content
Snippets Groups Projects
Unverified Commit f1f7b210 authored by Shenghang Tsai's avatar Shenghang Tsai Committed by GitHub
Browse files

refine XLA build (#3377)


* use tf protobuf headers

* add readme

* rm patch

* add inc file

* fix rare grpc build bug

* use different headers

* fix file path

* fix cublas conflict with tf

* use cublas dlib for xla

Co-authored-by: default avatartsai <caishenghang@1f-dev.kbaeegfb1x0ubnoznzequyxzve.bx.internal.cloudapp.net>
parent a309c1d2
No related branches found
No related tags found
No related merge requests found
......@@ -55,15 +55,28 @@ if (BUILD_CUDA)
break()
endif()
endforeach()
if(EXISTS ${cublas_lib_dir}/libcublas_static.a)
if(EXISTS ${cublas_lib_dir}/libcublasLt_static.a)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublasLt_static.a)
if (WITH_XLA)
if(EXISTS ${cublas_lib_dir}/libcublas.so AND EXISTS ${cublas_lib_dir}/libcublasLt.so)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublasLt.so)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublas.so)
elseif(EXISTS ${cublas_lib_dir}/libcublas.so)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublas.so)
elseif(EXISTS ${cuda_lib_dir}/libcublas.so)
list(APPEND CUDA_LIBRARIES ${cuda_lib_dir}/libcublas.so)
else()
message(FATAL_ERROR "cuda lib not found: ${cublas_lib_dir}/libcublas.so or ${cuda_lib_dir}/libcublas.so")
endif()
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublas_static.a)
elseif(EXISTS ${cuda_lib_dir}/libcublas_static.a)
list(APPEND CUDA_LIBRARIES ${cuda_lib_dir}/libcublas_static.a)
else()
message(FATAL_ERROR "cuda lib not found: ${cublas_lib_dir}/libcublas_static.a or ${cuda_lib_dir}/libcublas_static.a")
if(EXISTS ${cublas_lib_dir}/libcublas_static.a AND EXISTS ${cublas_lib_dir}/libcublasLt_static.a)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublasLt_static.a)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublas_static.a)
elseif(EXISTS ${cublas_lib_dir}/libcublas_static.a)
list(APPEND CUDA_LIBRARIES ${cublas_lib_dir}/libcublas_static.a)
elseif(EXISTS ${cuda_lib_dir}/libcublas_static.a)
list(APPEND CUDA_LIBRARIES ${cuda_lib_dir}/libcublas_static.a)
else()
message(FATAL_ERROR "cuda lib not found: ${cublas_lib_dir}/libcublas_static.a or ${cuda_lib_dir}/libcublas_static.a")
endif()
endif()
find_package(CUDNN REQUIRED)
endif()
......@@ -161,11 +174,6 @@ if (BUILD_CUDA)
include(cub)
include(nccl)
if (WITH_XLA)
# Fix conflicts between tensorflow cublas dso and oneflow static cublas.
# TODO(hjchen2) Should commit a issue about this fix.
list(APPEND oneflow_third_party_libs -Wl,--whole-archive ${cuda_lib_dir}/libcublas_static.a -Wl,--no-whole-archive)
endif()
list(APPEND oneflow_third_party_libs ${CUDA_LIBRARIES})
list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES})
list(APPEND oneflow_third_party_libs ${NCCL_STATIC_LIBRARIES})
......
......@@ -29,7 +29,7 @@ if(THIRD_PARTY)
ExternalProject_Add(grpc
PREFIX grpc
DEPENDS protobuf zlib
DEPENDS protobuf zlib zlib_copy_headers_to_destination
URL ${GRPC_URL}
UPDATE_COMMAND ""
BUILD_IN_SOURCE 1
......
google/protobuf/any.pb.h
google/protobuf/timestamp.pb.h
google/protobuf/map_field_lite.h
google/protobuf/any.h
google/protobuf/test_util_lite.h
google/protobuf/generated_enum_reflection.h
google/protobuf/extension_set.h
google/protobuf/message.h
google/protobuf/map_entry.h
google/protobuf/descriptor.pb.h
google/protobuf/extension_set_inl.h
google/protobuf/repeated_field.h
google/protobuf/map.h
google/protobuf/api.pb.h
google/protobuf/map_field.h
google/protobuf/reflection.h
google/protobuf/map_test_util.h
google/protobuf/generated_message_reflection.h
google/protobuf/map_entry_lite.h
google/protobuf/arenastring.h
google/protobuf/reflection_internal.h
google/protobuf/generated_message_table_driven.h
google/protobuf/generated_message_util.h
google/protobuf/has_bits.h
google/protobuf/test_util2.h
google/protobuf/unknown_field_set.h
google/protobuf/message_lite.h
google/protobuf/implicit_weak_message.h
google/protobuf/package_info.h
google/protobuf/arena_test_util.h
google/protobuf/descriptor_database.h
google/protobuf/empty.pb.h
google/protobuf/wire_format.h
google/protobuf/dynamic_message.h
google/protobuf/reflection_ops.h
google/protobuf/map_type_handler.h
google/protobuf/wire_format_lite.h
google/protobuf/map_test_util_impl.h
google/protobuf/text_format.h
google/protobuf/arena_impl.h
google/protobuf/metadata.h
google/protobuf/map_field_inl.h
google/protobuf/test_util.h
google/protobuf/generated_enum_util.h
google/protobuf/map_lite_test_util.h
google/protobuf/duration.pb.h
google/protobuf/struct.pb.h
google/protobuf/port.h
google/protobuf/parse_context.h
google/protobuf/inlined_string_field.h
google/protobuf/source_context.pb.h
google/protobuf/generated_message_table_driven_lite.h
google/protobuf/wrappers.pb.h
google/protobuf/metadata_lite.h
google/protobuf/type.pb.h
google/protobuf/descriptor.h
google/protobuf/field_mask.pb.h
google/protobuf/arena.h
google/protobuf/service.h
google/protobuf/util/field_comparator.h
google/protobuf/util/time_util.h
google/protobuf/util/package_info.h
google/protobuf/util/type_resolver.h
google/protobuf/util/json_util.h
google/protobuf/util/delimited_message_util.h
google/protobuf/util/field_mask_util.h
google/protobuf/util/type_resolver_util.h
google/protobuf/util/message_differencer.h
google/protobuf/util/internal/structured_objectwriter.h
google/protobuf/util/internal/constants.h
google/protobuf/util/internal/field_mask_utility.h
google/protobuf/util/internal/expecting_objectwriter.h
google/protobuf/util/internal/protostream_objectwriter.h
google/protobuf/util/internal/object_source.h
google/protobuf/util/internal/utility.h
google/protobuf/util/internal/type_info_test_helper.h
google/protobuf/util/internal/json_escaping.h
google/protobuf/util/internal/json_stream_parser.h
google/protobuf/util/internal/object_location_tracker.h
google/protobuf/util/internal/location_tracker.h
google/protobuf/util/internal/protostream_objectsource.h
google/protobuf/util/internal/datapiece.h
google/protobuf/util/internal/default_value_objectwriter.h
google/protobuf/util/internal/type_info.h
google/protobuf/util/internal/error_listener.h
google/protobuf/util/internal/proto_writer.h
google/protobuf/util/internal/mock_error_listener.h
google/protobuf/util/internal/object_writer.h
google/protobuf/util/internal/json_objectwriter.h
google/protobuf/testing/file.h
google/protobuf/testing/googletest.h
google/protobuf/compiler/zip_writer.h
google/protobuf/compiler/plugin.pb.h
google/protobuf/compiler/scc.h
google/protobuf/compiler/subprocess.h
google/protobuf/compiler/package_info.h
google/protobuf/compiler/code_generator.h
google/protobuf/compiler/annotation_test_util.h
google/protobuf/compiler/plugin.h
google/protobuf/compiler/importer.h
google/protobuf/compiler/mock_code_generator.h
google/protobuf/compiler/command_line_interface.h
google/protobuf/compiler/parser.h
google/protobuf/compiler/csharp/csharp_repeated_primitive_field.h
google/protobuf/compiler/csharp/csharp_doc_comment.h
google/protobuf/compiler/csharp/csharp_message.h
google/protobuf/compiler/csharp/csharp_enum.h
google/protobuf/compiler/csharp/csharp_generator.h
google/protobuf/compiler/csharp/csharp_options.h
google/protobuf/compiler/csharp/csharp_repeated_message_field.h
google/protobuf/compiler/csharp/csharp_reflection_class.h
google/protobuf/compiler/csharp/csharp_map_field.h
google/protobuf/compiler/csharp/csharp_wrapper_field.h
google/protobuf/compiler/csharp/csharp_source_generator_base.h
google/protobuf/compiler/csharp/csharp_message_field.h
google/protobuf/compiler/csharp/csharp_enum_field.h
google/protobuf/compiler/csharp/csharp_primitive_field.h
google/protobuf/compiler/csharp/csharp_helpers.h
google/protobuf/compiler/csharp/csharp_field_base.h
google/protobuf/compiler/csharp/csharp_repeated_enum_field.h
google/protobuf/compiler/csharp/csharp_names.h
google/protobuf/compiler/cpp/cpp_options.h
google/protobuf/compiler/cpp/cpp_primitive_field.h
google/protobuf/compiler/cpp/cpp_enum_field.h
google/protobuf/compiler/cpp/cpp_enum.h
google/protobuf/compiler/cpp/cpp_padding_optimizer.h
google/protobuf/compiler/cpp/cpp_map_field.h
google/protobuf/compiler/cpp/cpp_extension.h
google/protobuf/compiler/cpp/cpp_message.h
google/protobuf/compiler/cpp/cpp_string_field.h
google/protobuf/compiler/cpp/cpp_unittest.h
google/protobuf/compiler/cpp/cpp_field.h
google/protobuf/compiler/cpp/cpp_file.h
google/protobuf/compiler/cpp/cpp_message_field.h
google/protobuf/compiler/cpp/cpp_generator.h
google/protobuf/compiler/cpp/cpp_service.h
google/protobuf/compiler/cpp/cpp_message_layout_helper.h
google/protobuf/compiler/cpp/cpp_helpers.h
google/protobuf/compiler/objectivec/objectivec_file.h
google/protobuf/compiler/objectivec/objectivec_field.h
google/protobuf/compiler/objectivec/objectivec_map_field.h
google/protobuf/compiler/objectivec/objectivec_enum_field.h
google/protobuf/compiler/objectivec/objectivec_enum.h
google/protobuf/compiler/objectivec/objectivec_nsobject_methods.h
google/protobuf/compiler/objectivec/objectivec_message.h
google/protobuf/compiler/objectivec/objectivec_extension.h
google/protobuf/compiler/objectivec/objectivec_primitive_field.h
google/protobuf/compiler/objectivec/objectivec_oneof.h
google/protobuf/compiler/objectivec/objectivec_message_field.h
google/protobuf/compiler/objectivec/objectivec_helpers.h
google/protobuf/compiler/objectivec/objectivec_generator.h
google/protobuf/compiler/js/js_generator.h
google/protobuf/compiler/js/well_known_types_embed.h
google/protobuf/compiler/ruby/ruby_generator.h
google/protobuf/compiler/python/python_generator.h
google/protobuf/compiler/php/php_generator.h
google/protobuf/compiler/java/java_generator.h
google/protobuf/compiler/java/java_file.h
google/protobuf/compiler/java/java_name_resolver.h
google/protobuf/compiler/java/java_generator_factory.h
google/protobuf/compiler/java/java_extension.h
google/protobuf/compiler/java/java_doc_comment.h
google/protobuf/compiler/java/java_message_builder.h
google/protobuf/compiler/java/java_message_lite.h
google/protobuf/compiler/java/java_map_field.h
google/protobuf/compiler/java/java_message_field.h
google/protobuf/compiler/java/java_message_field_lite.h
google/protobuf/compiler/java/java_primitive_field.h
google/protobuf/compiler/java/java_enum_lite.h
google/protobuf/compiler/java/java_string_field_lite.h
google/protobuf/compiler/java/java_extension_lite.h
google/protobuf/compiler/java/java_options.h
google/protobuf/compiler/java/java_enum_field.h
google/protobuf/compiler/java/java_field.h
google/protobuf/compiler/java/java_map_field_lite.h
google/protobuf/compiler/java/java_primitive_field_lite.h
google/protobuf/compiler/java/java_names.h
google/protobuf/compiler/java/java_enum.h
google/protobuf/compiler/java/java_string_field.h
google/protobuf/compiler/java/java_enum_field_lite.h
google/protobuf/compiler/java/java_context.h
google/protobuf/compiler/java/java_helpers.h
google/protobuf/compiler/java/java_shared_code_generator.h
google/protobuf/compiler/java/java_service.h
google/protobuf/compiler/java/java_message_builder_lite.h
google/protobuf/compiler/java/java_message.h
google/protobuf/io/zero_copy_stream_impl_lite.h
google/protobuf/io/strtod.h
google/protobuf/io/printer.h
google/protobuf/io/io_win32.h
google/protobuf/io/package_info.h
google/protobuf/io/gzip_stream.h
google/protobuf/io/coded_stream.h
google/protobuf/io/zero_copy_stream_impl.h
google/protobuf/io/tokenizer.h
google/protobuf/io/zero_copy_stream.h
google/protobuf/io/coded_stream_inl.h
google/protobuf/stubs/int128.h
google/protobuf/stubs/mathlimits.h
google/protobuf/stubs/once.h
google/protobuf/stubs/time.h
google/protobuf/stubs/mutex.h
google/protobuf/stubs/strutil.h
google/protobuf/stubs/logging.h
google/protobuf/stubs/callback.h
google/protobuf/stubs/common.h
google/protobuf/stubs/macros.h
google/protobuf/stubs/stringprintf.h
google/protobuf/stubs/fastmem.h
google/protobuf/stubs/substitute.h
google/protobuf/stubs/stringpiece.h
google/protobuf/stubs/stl_util.h
google/protobuf/stubs/hash.h
google/protobuf/stubs/status_macros.h
google/protobuf/stubs/mathutil.h
google/protobuf/stubs/status.h
google/protobuf/stubs/casts.h
google/protobuf/stubs/port.h
google/protobuf/stubs/template_util.h
google/protobuf/stubs/map_util.h
google/protobuf/stubs/platform_macros.h
google/protobuf/stubs/statusor.h
google/protobuf/stubs/bytestream.h
google/protobuf/descriptor.proto
google/protobuf/source_context.proto
google/protobuf/timestamp.proto
google/protobuf/unittest_lite_imports_nonlite.proto
google/protobuf/map_proto2_unittest.proto
google/protobuf/unittest_optimize_for.proto
google/protobuf/unittest_proto3_lite.proto
google/protobuf/unittest_no_generic_services.proto
google/protobuf/unittest_no_arena_lite.proto
google/protobuf/unittest_import.proto
google/protobuf/unittest_lazy_dependencies_enum.proto
google/protobuf/unittest_lazy_dependencies_custom_option.proto
google/protobuf/unittest_custom_options.proto
google/protobuf/unittest_empty.proto
google/protobuf/map_lite_unittest.proto
google/protobuf/unittest_lite.proto
google/protobuf/wrappers.proto
google/protobuf/any.proto
google/protobuf/unittest_proto3_arena.proto
google/protobuf/unittest_proto3_arena_lite.proto
google/protobuf/empty.proto
google/protobuf/field_mask.proto
google/protobuf/unittest_import_public_lite.proto
google/protobuf/unittest_well_known_types.proto
google/protobuf/test_messages_proto3.proto
google/protobuf/unittest_preserve_unknown_enum2.proto
google/protobuf/unittest_arena.proto
google/protobuf/duration.proto
google/protobuf/unittest_preserve_unknown_enum.proto
google/protobuf/unittest_import_lite.proto
google/protobuf/unittest_enormous_descriptor.proto
google/protobuf/unittest_mset_wire_format.proto
google/protobuf/unittest_no_arena.proto
google/protobuf/unittest_embed_optimize_for.proto
google/protobuf/test_messages_proto2.proto
google/protobuf/unittest_proto3.proto
google/protobuf/map_unittest.proto
google/protobuf/unittest_no_arena_import.proto
google/protobuf/unittest.proto
google/protobuf/unittest_mset.proto
google/protobuf/struct.proto
google/protobuf/unittest_drop_unknown_fields.proto
google/protobuf/any_test.proto
google/protobuf/unittest_no_field_presence.proto
google/protobuf/unittest_import_public.proto
google/protobuf/api.proto
google/protobuf/type.proto
google/protobuf/unittest_lazy_dependencies.proto
google/protobuf/util/message_differencer_unittest.proto
google/protobuf/util/json_format.proto
google/protobuf/util/json_format_proto3.proto
google/protobuf/util/internal/testdata/default_value.proto
google/protobuf/util/internal/testdata/books.proto
google/protobuf/util/internal/testdata/default_value_test.proto
google/protobuf/util/internal/testdata/anys.proto
google/protobuf/util/internal/testdata/timestamp_duration.proto
google/protobuf/util/internal/testdata/wrappers.proto
google/protobuf/util/internal/testdata/field_mask.proto
google/protobuf/util/internal/testdata/maps.proto
google/protobuf/util/internal/testdata/proto3.proto
google/protobuf/util/internal/testdata/oneofs.proto
google/protobuf/util/internal/testdata/struct.proto
google/protobuf/compiler/plugin.proto
google/protobuf/compiler/cpp/cpp_test_large_enum_value.proto
google/protobuf/compiler/cpp/cpp_test_bad_identifiers.proto
google/protobuf/compiler/ruby/ruby_generated_pkg_implicit.proto
google/protobuf/compiler/ruby/ruby_generated_pkg_explicit_legacy.proto
google/protobuf/compiler/ruby/ruby_generated_code_proto2.proto
google/protobuf/compiler/ruby/ruby_generated_pkg_explicit.proto
google/protobuf/compiler/ruby/ruby_generated_code.proto
google/protobuf/message_unittest.inc
google/protobuf/port_undef.inc
google/protobuf/test_util.inc
google/protobuf/port_def.inc
google/protobuf/proto3_lite_unittest.inc
google/protobuf/compiler/cpp/cpp_unittest.inc
......@@ -59,7 +59,12 @@ ExternalProject_Add(protobuf
)
# put protobuf includes in the 'THIRD_PARTY_DIR'
add_copy_headers_target(NAME protobuf SRC ${PROTOBUF_SRC_DIR} DST ${PROTOBUF_INCLUDE_DIR} DEPS protobuf INDEX_FILE "${oneflow_cmake_dir}/third_party/header_index/protobuf_headers.txt")
if(WITH_XLA)
add_copy_headers_target(NAME protobuf SRC ${PROTOBUF_SRC_DIR} DST ${PROTOBUF_INCLUDE_DIR} DEPS protobuf INDEX_FILE "${oneflow_cmake_dir}/third_party/header_index/protobuf_xla_headers.txt")
else()
add_copy_headers_target(NAME protobuf SRC ${PROTOBUF_SRC_DIR} DST ${PROTOBUF_INCLUDE_DIR} DEPS protobuf INDEX_FILE "${oneflow_cmake_dir}/third_party/header_index/protobuf_headers.txt")
endif()
# put protobuf librarys in the 'THIRD_PARTY_DIR'
add_custom_target(protobuf_create_library_dir
......
......@@ -25,10 +25,10 @@ endif()
message(STATUS ${TENSORFLOW_BUILD_CMD})
set(TENSORFLOW_PROJECT tensorflow)
set(TENSORFLOW_GIT_URL https://github.com/tensorflow/tensorflow.git)
set(TENSORFLOW_GIT_URL https://github.com/OneFlow-Inc/tensorflow.git)
#set(TENSORFLOW_GIT_TAG master)
set(TENSORFLOW_GIT_TAG 80c04b80ad66bf95aa3f41d72a6bba5e84a99622)
set(TENSORFLOW_SOURCES_DIR ${THIRD_PARTY_SUBMODULE_DIR}/tensorflow)
set(TENSORFLOW_GIT_TAG dea9488e5f05ffcaff7e729f33d475af3a7021ba)
set(TENSORFLOW_SOURCES_DIR ${CMAKE_CURRENT_BINARY_DIR}/tensorflow)
set(TENSORFLOW_SRCS_DIR ${TENSORFLOW_SOURCES_DIR}/src/tensorflow)
set(TENSORFLOW_INC_DIR ${TENSORFLOW_SOURCES_DIR}/src/tensorflow)
......@@ -64,7 +64,6 @@ if (THIRD_PARTY)
PREFIX ${TENSORFLOW_SOURCES_DIR}
GIT_REPOSITORY ${TENSORFLOW_GIT_URL}
GIT_TAG ${TENSORFLOW_GIT_TAG}
PATCH_COMMAND patch -Np1 < ${PATCHES_DIR}/xla.patch
CONFIGURE_COMMAND ""
BUILD_COMMAND cd ${TENSORFLOW_SRCS_DIR} &&
bazel build ${TENSORFLOW_BUILD_CMD} -j 20 //tensorflow/compiler/jit/xla_lib:libxla_core.so
......
......@@ -15,6 +15,17 @@ docker build -f docker/package/manylinux/Dockerfile --build-arg from=nvidia/cuda
docker run --rm -it -v `pwd`:/oneflow-src -w /oneflow-src oneflow:manylinux2014-cuda10.2
```
If you prefer operate inside docker:
```bash
docker run --rm -it -v `pwd`:/oneflow-src -w /oneflow-src oneflow:manylinux2014-cuda10.2 bash
```
```bash
/oneflow-src/docker/package/manylinux/build_wheel.sh --python3.6 --wheel-dir /oneflow-src/wheel-test
```
就会在 docker 镜像里执行 build_wheel.sh 来编译生成 python 3.5 到 python 3.8 的 oneflow manylinux2014 wheel。生成的包在 oneflow 源码目录下的 wheelhouse/ 文件夹内
#### 注意事项
......
......@@ -15,7 +15,7 @@ def glob_by_pattern(pattern):
for x in glob.glob(os.path.join(args.src_path, pattern), recursive=True):
result.append(os.path.relpath(x, args.src_path))
return result
headers = glob_by_pattern('**/*.h') + glob_by_pattern('**/*.hpp') + glob_by_pattern('**/*.cuh') + glob_by_pattern('**/*.proto')
headers = glob_by_pattern('**/*.h') + glob_by_pattern('**/*.hpp') + glob_by_pattern('**/*.cuh') + glob_by_pattern('**/*.proto') + glob_by_pattern('**/*.inc')
with open(args.dst_file, 'w') as f:
for item in headers:
f.write("{}\n".format(item))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment