Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document the most important API functions #495

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,42 @@ typedef enum {
DEM_ForwardModeSplit = 4,
} CDerivativeMode;

/// Generates a new function based on the input, which uses forward-mode AD.
///
Copy link
Member

@wsmoses wsmoses Feb 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to document the formal function: https://enzyme.mit.edu/doxygen/classEnzymeLogic.html#a8e6ba5c54950f1ebcdfcbd7ce2a16083

I'm slightly inclined to have the C API documentation be light, and instead forward one to see the C++ function which it wraps (this applies to all of the functions in the C API)

/// @param mode: DEM_ForwardMode and DEM_ForwardModeSplit
/// @param retType: DFT_DUP_ARG, DFT_CONSTANT
/// @param constant_args: pointing to combinations of DFT_DUP_ARG and
/// DFT_CONSTANT
/// @param width: integer n >= 1. The generated functions expects n additional
/// inputs for each input marked as DFT_DUP_ARG.
/// @param returnValue: Whether to return the primary return value.
/// The similar dret_used is not available here / implicitely set to true, as it
/// is the only place where we return gradient information.
/// DEM_ReverseModeCombined. Otherwise, should be set to the amount of
/// input params of `todiff`, which might change between the forward and the
/// reverse pass.
LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
EnzymeTypeAnalysisRef TA, uint8_t returnValue, CDerivativeMode mode,
unsigned width, LLVMTypeRef additionalArg, struct CFnTypeInfo typeInfo,
uint8_t *_uncacheable_args, size_t uncacheable_args_size);

/// Generates a new function based on the input, which uses reverse-mode AD.
///
/// Based on the
/// @param retType: When returning f32/f64 we might use DFT_OUT_DIFF.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use llvm type names here: float/double. We also would want to state the general rule that out_diff is used when the value being differentiated is being passed as a register (examples being float double, a vector of these, an llvm.struct of these, etc) rather than through memory (e.g. a pointer)

/// When returning anything else, one should use DFT_CONSTANT
/// @param width: currently only supporting width=1 here
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't add this as a comment here, and instead describe what width should do.

/// @param mode: Should be any of DEM_ReverseModePrimal,
/// DEM_ReverseModeGradient, DEM_ReverseModeCombined.
/// @param augmented: Pass a nullptr, in case of TODO
/// @param AtomicAdd: Enables thread-safe accumulates for being called within
/// a parallel context.
/// @param uncacheable_args_size: Should be set to 0, as long as using
/// DEM_ReverseModeCombined. Otherwise, should be set to the amount of
/// input params of `todiff`, which might change between the forward and the
/// reverse pass.
LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
Expand All @@ -133,6 +162,17 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
size_t uncacheable_args_size, EnzymeAugmentedReturnPtr augmented,
uint8_t AtomicAdd);

/// Generates an augmented forward function based on the input.
///
/// Cached information will be stored on a Tape.
/// Will usually be used in combination with EnzymeCreatePrimalAndGradient for
/// reverse-mode-split AD.
/// @param retType: DFT_DUP_ARG, DFT_CONSTANT
/// @param constant_args: pointing to combinations of DFT_DUP_ARG and
/// DFT_CONSTANT
/// @param forceAnonymousTape: TODO
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per https://enzyme.mit.edu/doxygen/classEnzymeLogic.html#a8e6ba5c54950f1ebcdfcbd7ce2a16083, forceAnonymousTape forces the tape to be an i8* rather than the true tape.

/// @param AtomicAdd: Enables thread-safe accumulates for being called within
/// a parallel context.
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
CDIFFE_TYPE *constant_args, size_t constant_args_size,
Expand All @@ -144,13 +184,25 @@ typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,
CTypeTreeRef * /*args*/,
struct IntList * /*knownValues*/,
size_t /*numArgs*/, LLVMValueRef);

/// Creates a new TypeAnalysis, to be used by the three EnzymeCreate functions.
///
/// Usually the TypeAnalysisRef will be reused for a complete compilation
/// process.
EnzymeTypeAnalysisRef CreateTypeAnalysis(EnzymeLogicRef Log,
char **customRuleNames,
CustomRuleType *customRules,
size_t numRules);
void ClearTypeAnalysis(EnzymeTypeAnalysisRef);
void FreeTypeAnalysis(EnzymeTypeAnalysisRef);

/// This will be used by enzyme to manage some internal settings.
///
/// Enzyme requires two optimization runs for best performance, one before AD
/// and one after. The second one can be applied automatically be setting
/// `PostOpt` to 1. Usually the LogicRef will be reused for a complete
/// compilation process.
/// @param PostOpt Should be set to 1, except for debug builds.
EnzymeLogicRef CreateEnzymeLogic(uint8_t PostOpt);
void ClearEnzymeLogic(EnzymeLogicRef);
void FreeEnzymeLogic(EnzymeLogicRef);
Expand Down