Differentiating a virtual function through a free function doesn't work #2666

@ilagunap


I am trying to differentiate a virtual function using the Clang plugin. I understand this is a challenging use case for Enzyme, but I am trying to determine whether I am using it correctly or encountering a bug.

Following the documentation, I created a free function that takes the class object as an argument, and I ask Enzyme to differentiate that free function.

The setup is a base class, Tensor, and two derived classes, CustomTensorA and CustomTensorB.

When I pass the class object to __enzyme_autodiff as enzyme_const, I get the following compilation error (see lines 50–56 in model.cpp):

```
error: Enzyme: No create nofree of unknown value
ptr %0
 at context:   call void %8(ptr noundef nonnull align 8 dereferenceable(8) %0, ptr noundef %1, ptr noundef %2, i32 noundef %3, ptr noundef %4) #8
1 error generated.
make: *** [Makefile:15: model.o] Error 1
```

If instead I pass the class object as enzyme_dup, the code compiles, but I get a segmentation fault at runtime when executing ./main.

What is the correct way to differentiate a virtual function in this setup?

I have attached the code along with a Makefile.

Thanks in advance for any guidance.

Classes:

```cpp
class Tensor
{
public:
    struct Parameters
    {
        // Inputs
        const double *low_values{nullptr};
        const double *high_values{nullptr};
        const double a1{1.23};
        const double a2{4.56};
    };

    // Constructor
    Tensor() {}
    virtual ~Tensor() = default;

    void evaluate(Parameters &params);
    virtual void evaluate_values(const double *low_values, const double *high_values, int length, Parameters *params);
};

class CustomTensorA : public Tensor
{
private:
    double multiplier;

public:
    CustomTensorA(double mult);
    using Tensor::evaluate_values;
    void evaluate_values(
        const double *low_values,
        const double *high_values,
        int length,
        Parameters *params) override;
};

class CustomTensorB : public Tensor
{
private:
    double offset;

public:
    CustomTensorB(double off);
    using Tensor::evaluate_values;
    void evaluate_values(
        const double *low_values,
        const double *high_values,
        int length,
        Parameters *params) override;
};
```

Free function and autodiff function:

```cpp
#include <cstdio>

// Free function for Enzyme
void _free_evaluate(Tensor *model,
                    const double *low_values,
                    const double *high_values,
                    int length,
                    Tensor::Parameters *args)
{
    model->evaluate_values(low_values, high_values, length, args);
}

int enzyme_dup;
int enzyme_const;
int enzyme_dupnoneed;
void __enzyme_autodiff(void *, ...);

void Tensor::evaluate(Parameters &params)
{
    // For demonstration, assume a fixed length
    const int length = 5;

    // Original call to the virtual function
    //evaluate_values(params.low_values, params.high_values, length, &params);

    // Call the free function, which works like the virtual function
    //_free_evaluate(this, params.low_values, params.high_values, length, &params);

    // Shadow (derivative) buffers; the trailing () zero-initializes them,
    // because Enzyme's reverse mode accumulates into shadows instead of overwriting them
    double *d_low_values = new double[length]();
    double *d_high_values = new double[length]();

    // Shadow for `this`: raw, unconstructed memory the size of a Tensor
    Tensor *d_this = static_cast<Tensor *>(operator new(sizeof(Tensor))); // do I have to initialize it? Not sure.

    __enzyme_autodiff((void *)_free_evaluate,
                      //enzyme_const, this,  // fails to compile: "No create nofree of unknown value"
                      enzyme_dup, this, d_this,
                      enzyme_dup, params.low_values, d_low_values,
                      enzyme_dup, params.high_values, d_high_values,
                      enzyme_const, length,
                      enzyme_const, &params);

    printf("[Enzyme] d_low_values[0]: %f\n", d_low_values[0]);
    printf("[Enzyme] d_high_values[0]: %f\n", d_high_values[0]);

    delete[] d_low_values;
    delete[] d_high_values;
    operator delete(d_this);
}
```

Source code:

test_virtual_functions.tar.gz
